{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=6>Convolutional Neural Networks</font>\n",
    "# Computer Vision\n",
    "## Applications\n",
    "<table>\n",
    "    <tr><th>Image recognition</th><th>Object detection</th><th>Style transfer</th></tr>\n",
    "    <tr>\n",
    "    <td><img src=\"images/373615de4e30035c662958ce39115fb4.png\"></td>\n",
    "    <td><img src=\"images/f8ff84bc95636d9e37e35daef5149164.png\"></td>\n",
    "    <td><img src=\"images/bf57536975bce32f78c9e66a2360e8a1.png\"></td>\n",
    "    </tr></table>\n",
    "    \n",
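    "## Problems with conventional neural networks in computer vision\n",
    "A $1000\\times1000$ RGB image gives an input of shape $(m,1000,1000,3)$, i.e. a feature vector with $1000\\times1000\\times3=3\\times10^6$ dimensions. If the first hidden layer has 1000 units, its weight matrix $W^{[1]}$ has shape $(10^3, 3\\times10^6)$, which is $3\\times10^9$ (3 billion) parameters for that single layer.\n"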
   ]
  },
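  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a quick check of the arithmetic above. It is a minimal sketch that only multiplies the shapes and ignores bias terms. The variable names are illustrative.\n",
    "\n",
    "```python\n",
    "# Input: one 1000 x 1000 RGB image, flattened into a feature vector\n",
    "n_x = 1000 * 1000 * 3      # 3,000,000 input features\n",
    "n_h = 1000                 # units in the first hidden layer\n",
    "\n",
    "# A fully connected W1 has shape (n_h, n_x)\n",
    "w1_params = n_h * n_x\n",
    "print(f'{n_x:,} inputs -> {w1_params:,} weights in W1')\n",
    "# 3,000,000 inputs -> 3,000,000,000 weights in W1\n",
    "```"
   ]
  },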
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature extraction, edge detection and the convolution operation\n",
    "1. Abstract features $\\to$ concrete features\n",
    "2. Small features $\\to$ local features $\\to$ global features\n",
    "<img src=\"images/59e76bb2-2a99-4c8f-af79-a5249d8eac41.png\">\n",
    "\n",
    "## Horizontal and vertical edges\n",
    "<img src=\"images/47c14f666d56e509a6863e826502bda2.png\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
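    "## The convolution operation\n",
    "Consider an example: a 6×6 grayscale image. Because it is grayscale, it is a 6×6×1 matrix rather than 6×6×3, since it has no **RGB** channels. To detect vertical edges in this image, you can construct a 3×3 matrix $\\begin{bmatrix}1 & 0 & -1\\\\ 1 & 0 & -1\\\\ 1 & 0 & -1\\end{bmatrix}$ (in CNN terminology this is called a filter) and convolve the image with it.\n",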
    "\n",
    "<img src=\"images/7099a5373f2281626aa8ddd47a180571.png\">\n",
    "\n",
    "The output of this convolution is a 4×4 matrix, which you can think of as a 4×4 image. To compute the first (top-left) element of the 4×4 output, place the 3×3 filter over the top-left 3×3 region of the input image and multiply element-wise: $\\begin{bmatrix} 3 \\times 1 & 0 \\times 0 & 1 \\times (-1) \\\\ 1 \\times 1 & 5 \\times 0 & 8 \\times (-1) \\\\ 2 \\times 1 & 7 \\times 0 & 2 \\times (-1) \\end{bmatrix} = \\begin{bmatrix} 3 & 0 & -1 \\\\ 1 & 0 & -8 \\\\ 2 & 0 & -2 \\end{bmatrix}$. Then add up all the elements of this matrix to get the top-left output value, $3+1+2+0+0+0+(-1)+(-8)+(-2)=-5$. The remaining output values are computed the same way, by sliding the filter across the image.\n",
    "<img src=\"images/5f9c10d0986f003e5bd6fa87a9ffe04b.png\">"
   ]
  },
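  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sliding-window computation above can be sketched in NumPy as follows. This is a minimal illustration that assumes stride 1 and no padding. Like CNN libraries, it does not flip the kernel, so strictly speaking it is a cross-correlation. The text gives only the top-left 3×3 window of the example image, so the sketch checks that one value and uses random data for the 6×6 shape check. `conv2d_valid` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def conv2d_valid(image, kernel):\n",
    "    # 'Valid' 2-D convolution as used in CNNs (no kernel flip), stride 1,\n",
    "    # no padding: an n x n input and f x f kernel give (n - f + 1) x (n - f + 1)\n",
    "    n_h, n_w = image.shape\n",
    "    f = kernel.shape[0]\n",
    "    out = np.zeros((n_h - f + 1, n_w - f + 1))\n",
    "    for i in range(out.shape[0]):\n",
    "        for j in range(out.shape[1]):\n",
    "            # element-wise product of the window and the kernel, then sum\n",
    "            out[i, j] = np.sum(image[i:i + f, j:j + f] * kernel)\n",
    "    return out\n",
    "\n",
    "vertical = np.array([[1, 0, -1],\n",
    "                     [1, 0, -1],\n",
    "                     [1, 0, -1]])\n",
    "\n",
    "# Top-left 3x3 window of the example image, as given in the text\n",
    "window = np.array([[3, 0, 1],\n",
    "                   [1, 5, 8],\n",
    "                   [2, 7, 2]])\n",
    "print(conv2d_valid(window, vertical))   # a 1x1 result holding -5\n",
    "\n",
    "# Any 6x6 input gives a 4x4 output\n",
    "print(conv2d_valid(np.random.rand(6, 6), vertical).shape)   # (4, 4)\n",
    "```"
   ]
  },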
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vertical edge detection\n",
    "1. Extract features\n",
    "2. Amplify features\n",
    "<img src=\"images/0c8b5b8441557b671431d515aefa1e8a.png\">"
   ]
  },
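  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why this filter responds to vertical edges, here is a minimal sketch on a hypothetical 6×6 image. The left half is bright (10) and the right half is dark (0). These values are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def conv2d_valid(image, kernel):\n",
    "    # stride-1, no-padding convolution (no kernel flip, as in CNNs)\n",
    "    f = kernel.shape[0]\n",
    "    h, w = image.shape[0] - f + 1, image.shape[1] - f + 1\n",
    "    return np.array([[np.sum(image[i:i + f, j:j + f] * kernel)\n",
    "                      for j in range(w)] for i in range(h)])\n",
    "\n",
    "# Hypothetical 6x6 image: bright left half (10), dark right half (0)\n",
    "image = np.zeros((6, 6))\n",
    "image[:, :3] = 10\n",
    "\n",
    "vertical = np.array([[1, 0, -1]] * 3)   # the 3x3 vertical-edge filter\n",
    "edges = conv2d_valid(image, vertical)\n",
    "print(edges)\n",
    "# Every row holds the values 0, 30, 30, 0: the large values mark the boundary\n",
    "```\n",
    "\n",
    "The detected edge is two columns wide because the image is tiny compared with the filter. On a large image it would look like a thin line."
   ]
  },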
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Edge detection filters\n",
    "## Distinguishing the direction of intensity changes\n",
    "<img src=\"images/6a248e5698d1f61ac4ba0238363c4a37.png\">\n",
    "\n",
    "## Horizontal edge filter\n",
    "1. Gradual edges\n",
    "<img src=\"images/f4adb9d91879e1c1aaef9bc9e244c64a.png\">\n",
    "\n",
    "## Filters as learnable parameters\n",
    "<img src=\"images/f889ad7011738a23d78070e8ed2df04e.png\">"
   ]
  },
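  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is a sketch of the horizontal counterpart. It assumes the horizontal filter is the transpose of the vertical one. It also shows that the sign of the response separates a bright-to-dark transition from a dark-to-bright one. The image values are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def conv2d_valid(image, kernel):\n",
    "    # stride-1, no-padding convolution (no kernel flip, as in CNNs)\n",
    "    f = kernel.shape[0]\n",
    "    h, w = image.shape[0] - f + 1, image.shape[1] - f + 1\n",
    "    return np.array([[np.sum(image[i:i + f, j:j + f] * kernel)\n",
    "                      for j in range(w)] for i in range(h)])\n",
    "\n",
    "vertical = np.array([[1, 0, -1]] * 3)\n",
    "horizontal = vertical.T        # rows: [1, 1, 1], [0, 0, 0], [-1, -1, -1]\n",
    "\n",
    "# Hypothetical 6x6 image: bright top half, dark bottom half\n",
    "top_bright = np.zeros((6, 6))\n",
    "top_bright[:3, :] = 10\n",
    "response = conv2d_valid(top_bright, horizontal)[:, 0]\n",
    "print(response)   # values 0, 30, 30, 0 down the first column\n",
    "\n",
    "# Reversing the transition (dark on top) flips the sign of the response\n",
    "flipped = conv2d_valid(10 - top_bright, horizontal)[:, 0]\n",
    "print(flipped)    # values 0, -30, -30, 0\n",
    "```\n",
    "\n",
    "In a CNN these nine numbers are not hand-picked. They are treated as parameters and learned by backpropagation."
   ]
  },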
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
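    "# Padding\n",
    "Suppose the input image is $n\\times n$ and the filter is $f\\times f$. Then the output of the convolution is $(n-f+1)\\times(n-f+1)$.\n",
    "\n",
    "**This causes two problems:**\n",
    "1. The output shrinks after every convolution;\n",
    "2. Pixels at the corners and edges of the original image are covered by fewer filter positions than pixels in the middle, so the output loses much of the information near the border.\n",
    "\n",
    "To address these problems, the original image can be padded along its border before the convolution, which enlarges the matrix. Zero is usually used as the padding value. There are three padding modes:\n",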
    "<table>\n",
    "    <tr>\n",
    "        <th>full</th><th>same</th><th>valid</th>\n",
    "    </tr>\n",
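    "    <tr><td>Convolution starts as soon as the filter and the image first overlap ($p=f-1$), so the output is $(n+f-1)\\times(n+f-1)$</td><td>$(n+2p-f+1)\\times(n+2p-f+1)$ with $p=\\frac{f-1}{2}$, so the output stays $n\\times n$<br>Same is the most common mode, because it keeps the feature-map size unchanged during forward propagation</td><td>Convolution is performed only where the filter lies entirely inside the image ($p=0$), so the filter moves over a smaller range than in same mode and the output is $(n-f+1)\\times(n-f+1)$</td></tr>\n",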
    "    <tr><td><img src=\"images/20180515205400757.png\"></td>\n",
    "        <td><img src=\"images/20180515205624201.png\"></td>\n",
    "        <td><img src=\"images/20180515205946981.png\"></td>\n",
    "    </tr>\n",
    "</table>\n",
    "\n",
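    "**Filter size:**\n",
    "1. If $f$ is even, only asymmetric padding is possible. Only when $f$ is odd does **Same** convolution have a natural padding: the same amount on all four sides, instead of more on the left and less on the right.\n",
    "2. An odd-sized filter, such as 3×3 or 5×5, has a center point. In computer vision a central pixel is sometimes convenient, because it makes it easy to refer to the filter's position."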
   ]
  },
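  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of the three padding modes for stride 1, using the padding amounts from the table above. It also zero-pads a 6×6 image with `numpy.pad`. `conv_output_size` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def conv_output_size(n, f, p):\n",
    "    # stride-1 output size for an n x n input, f x f filter and padding p\n",
    "    return n + 2 * p - f + 1\n",
    "\n",
    "n, f = 6, 3\n",
    "p_valid = 0                 # valid: no padding\n",
    "p_same = (f - 1) // 2       # same: p = (f - 1) / 2, an integer when f is odd\n",
    "p_full = f - 1              # full: start as soon as filter and image overlap\n",
    "\n",
    "print(conv_output_size(n, f, p_valid))   # 4\n",
    "print(conv_output_size(n, f, p_same))    # 6\n",
    "print(conv_output_size(n, f, p_full))    # 8\n",
    "\n",
    "# Zero-padding a 6x6 image by p_same on every side gives an 8x8 array\n",
    "padded = np.pad(np.ones((n, n)), p_same, mode='constant', constant_values=0)\n",
    "print(padded.shape)                      # (8, 8)\n",
    "```"
   ]
  },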
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
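    "# Stride\n",
    "During convolution we sometimes pad to avoid losing information, and sometimes set a stride to compress some of it. The stride is the distance the filter moves at each step, horizontally and vertically, over the input image. Until now the stride has been 1 by default.\n",
    "\n",
    "**Output size:**\n",
    "With stride $s$, padding $p$, an $n\\times n$ input and an $f\\times f$ filter, the output of the convolution has size:\n",
    "$$\\left(\\left\\lfloor\\frac{n+2p-f}{s}\\right\\rfloor+1\\right)\\times\\left(\\left\\lfloor\\frac{n+2p-f}{s}\\right\\rfloor+1\\right)$$\n",
    "where $\\lfloor\\cdot\\rfloor$ denotes rounding down (floor).\n",
    "\n",
    "**Example**\n",
    "\n",
    "The output size in the example below is $\\left\\lfloor\\frac{7+2\\times0-3}{2}\\right\\rfloor+1=3$\n",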
    "<img src=\"images/9cacda308d53adb7d154a3b259569f45.png\">"
   ]
  },
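  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output-size formula, with its floor, can be sketched with Python's integer division. The helper name is hypothetical. The second call shows a case where the floor matters.\n",
    "\n",
    "```python\n",
    "def conv_output_size(n, f, p=0, s=1):\n",
    "    # floor((n + 2p - f) / s) + 1; for non-negative values // rounds down\n",
    "    return (n + 2 * p - f) // s + 1\n",
    "\n",
    "print(conv_output_size(7, 3, p=0, s=2))   # 3, the example above\n",
    "print(conv_output_size(8, 3, p=0, s=2))   # 3, because floor(5 / 2) = 2\n",
    "\n",
    "# Starting offsets of the filter along one axis in the 7x7, s=2 example\n",
    "print([k * 2 for k in range(conv_output_size(7, 3, s=2))])   # [0, 2, 4]\n",
    "```"
   ]
  },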
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
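    "# Convolution over volumes\n",
    "To convolve a three-channel RGB image, the filter must also have three channels. Each single channel (R, G, B) is convolved with the corresponding filter channel, and the results of the three channels are added together. For a 3×3×3 filter, the sum of these 27 products gives one pixel of the output image.\n",
    "\n",
    "The channels of a filter can differ from one another. For example, to detect vertical edges only in the R channel and not in G or B, set the G and B channels of the filter to zero. For an input with a given height, width and number of channels, the filter may have a different height and width, **but its number of channels must match the input**.\n",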
    "<img src=\"images/9b0b0e9062f8814a6a462ea64449f89e.png\">\n",
    "\n",
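    "To detect vertical and horizontal edges at the same time, or more kinds of edges, add more filters. For example, the first filter detects vertical edges and the second detects horizontal edges. If the input is $n\\times n\\times n_c$ ($n_c$ is the number of channels) and each filter is $f\\times f\\times n_c$, then (with stride 1 and no padding) the output is $(n-f+1)\\times (n-f+1)\\times n^{'}_c$, where $n^{'}_c$ is the number of filters.\n",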
    "<img src=\"images/794b25829ae809f93ac69f81eee79cd1.png\">"
   ]
  },
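  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of convolving a 6×6×3 volume with two 3×3×3 filters. It uses stride 1 and no padding. The filters are stored in the `(f, f, n_c, k)` layout used by the exercise code later in this notebook. Filter 0 looks only at the R channel and filter 1 detects horizontal edges in all channels. These choices, and the helper name `conv_volume`, are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def conv_volume(image, filters):\n",
    "    # image: (n, n, n_c); filters: (f, f, n_c, k), i.e. k filters whose channel\n",
    "    # count must equal the input's. Stride 1, no padding.\n",
    "    n, _, n_c = image.shape\n",
    "    f, _, fc, k = filters.shape\n",
    "    assert fc == n_c, 'filter channels must equal input channels'\n",
    "    out = np.zeros((n - f + 1, n - f + 1, k))\n",
    "    for i in range(n - f + 1):\n",
    "        for j in range(n - f + 1):\n",
    "            window = image[i:i + f, j:j + f, :]          # f x f x n_c window\n",
    "            for c in range(k):\n",
    "                # f*f*n_c products (27 for 3x3x3) summed into one output value\n",
    "                out[i, j, c] = np.sum(window * filters[:, :, :, c])\n",
    "    return out\n",
    "\n",
    "vertical = np.array([[1, 0, -1]] * 3)\n",
    "filters = np.zeros((3, 3, 3, 2))\n",
    "filters[:, :, 0, 0] = vertical                  # filter 0: vertical edges in R only\n",
    "filters[:, :, :, 1] = vertical.T[:, :, None]    # filter 1: horizontal edges, all channels\n",
    "\n",
    "image = np.random.rand(6, 6, 3)\n",
    "out = conv_volume(image, filters)\n",
    "print(out.shape)   # (4, 4, 2): one output channel per filter\n",
    "\n",
    "# Filter 0 ignores G and B, so zeroing them leaves output channel 0 unchanged\n",
    "r_only = image.copy()\n",
    "r_only[:, :, 1:] = 0\n",
    "print(np.allclose(out[:, :, 0], conv_volume(r_only, filters)[:, :, 0]))   # True\n",
    "```"
   ]
  },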
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Forward propagation in a convolutional network\n",
    "## Formulas\n",
    "**Standard neural network:**\n",
    "\n",
    "$z^{[l]} = W^{[l]}a^{[l-1]} + b^{[l]}$\n",
    "\n",
    "$a^{[l]} = g(z^{[l]})$\n",
    "\n",
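    "**Convolutional network:**\n",
    "\n",
    "$z^{[l]} = c(W^{[l]},a^{[l-1]}) + b^{[l]}$, where $c(\\cdot,\\cdot)$ denotes convolving $a^{[l-1]}$ with each filter in $W^{[l]}$\n",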
    "\n",
    "$a^{[l]} = g(z^{[l]})$\n",
    "\n",
    "<img src=\"images/1236aad1426bd324964091f1d9db2587.png\">\n",
    "\n",
    "## Shapes\n",
    "$f^{[l]}$ denotes the filter size, i.e. the filters of layer $l$ are $f\\times f$; in general the superscript $[l]$ marks layer $l$. $p^{[l]}$ denotes the amount of **padding**, which can also be specified as a **valid** convolution (no padding) or a **same** convolution (padding chosen so that the output has the same height and width as the input). $s^{[l]}$ denotes the stride.\n",
    "\n",
    "The input $A^{[l-1]}$ has shape $m\\times n_{H}^{[l-1]} \\times n_{W}^{[l-1]} \\times n_{c}^{[l-1]}$\n",
    "\n",
    "The output $A^{[l]}$ has shape $m\\times n_{H}^{[l]} \\times n_{W}^{[l]} \\times n_{c}^{[l]}$\n",
    "\n",
    "$n_{c}^{[l]}$ is the number of filters in layer $l$, and $m$ is the number of examples.\n",
    "\n",
    "Here $n_{H}^{[l]} = \\left\\lfloor\\frac{n_{H}^{[l-1]} + 2p^{[l]} - f^{[l]}}{s^{[l]}}\\right\\rfloor + 1$, and $n_{W}^{[l]}$ is computed in the same way.\n",
    "\n",
    "The filters $W^{[l]}$ have shape $f^{[l]} \\times f^{[l]} \\times n_{c}^{[l-1]} \\times n_{c}^{[l]}$ (the layout used by the code below)\n",
    "\n",
    "The bias $b^{[l]}$ has shape $1\\times 1\\times 1\\times n_{c}^{[l]}$"
   ]
  },
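  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape rules above can be collected into one helper. This is a minimal sketch with a hypothetical name and signature. It returns the output shape, the shapes of $W^{[l]}$ and $b^{[l]}$ in the `(f, f, n_c_prev, n_c)` layout, and the number of learnable parameters. That count does not depend on the input's height or width.\n",
    "\n",
    "```python\n",
    "def conv_layer_shapes(n_h_prev, n_w_prev, n_c_prev, f, p, s, n_c):\n",
    "    # Output height/width: floor((n + 2p - f) / s) + 1\n",
    "    n_h = (n_h_prev + 2 * p - f) // s + 1\n",
    "    n_w = (n_w_prev + 2 * p - f) // s + 1\n",
    "    w_shape = (f, f, n_c_prev, n_c)     # W[l], as laid out in this notebook's code\n",
    "    b_shape = (1, 1, 1, n_c)            # one bias per filter\n",
    "    n_params = f * f * n_c_prev * n_c + n_c\n",
    "    return (n_h, n_w, n_c), w_shape, b_shape, n_params\n",
    "\n",
    "# e.g. ten 3x3 filters on a 39x39x3 input with s=1, p=0\n",
    "print(conv_layer_shapes(39, 39, 3, f=3, p=0, s=1, n_c=10))\n",
    "# ((37, 37, 10), (3, 3, 3, 10), (1, 1, 1, 10), 280)\n",
    "```"
   ]
  },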
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example convolutional network: shapes layer by layer\n",
    "$\n",
    "\\large\\underset{l=0}{\\underbrace{\\left [ m,n_H=39,n_W=39,n_c=3 \\right ]}}\\xrightarrow[f_H=f_W=3,s=1,p=0,filters=10]{}\\\\\n",
    "\\large\\underset{l=1}{\\underbrace{\\left [ m,n_H=37,n_W=37,n_c=10 \\right ]}}\\xrightarrow[f_H=f_W=5,s=2,p=0,filters=20]{}\\\\\n",
    "\\large\\underset{l=2}{\\underbrace{\\left [ m,n_H=17,n_W=17,n_c=20 \\right ]}}\\xrightarrow[f_H=f_W=5,s=2,p=0,filters=40]{}\\\\\n",
    "\\large\\underset{l=3}{\\underbrace{\\left [ m,n_H=7,n_W=7,n_c=40 \\right ]}}\\to\\\\\n",
    "\\large\\underset{l=3}{\\underbrace{\\left [ m,7\\times 7\\times 40=1960\\right ]}}\\xrightarrow[softmax]{}\\hat y\n",
    "$"
   ]
  },
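  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape chain above can be reproduced by applying the output-size formula layer by layer. This is a minimal sketch. The `(f, s, p, filters)` tuples are copied from the example.\n",
    "\n",
    "```python\n",
    "layers = [  # (f, s, p, number of filters), as in the example above\n",
    "    (3, 1, 0, 10),\n",
    "    (5, 2, 0, 20),\n",
    "    (5, 2, 0, 40),\n",
    "]\n",
    "\n",
    "n, n_c = 39, 3\n",
    "for f, s, p, k in layers:\n",
    "    n = (n + 2 * p - f) // s + 1\n",
    "    n_c = k\n",
    "    print((n, n, n_c))\n",
    "# (37, 37, 10), then (17, 17, 20), then (7, 7, 40)\n",
    "\n",
    "flat = n * n * n_c\n",
    "print(flat)   # 1960 features fed to the softmax / logistic output\n",
    "```"
   ]
  },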
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
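    "# Types of hidden layers in a CNN\n",
    "1. Convolutional layer (CONV): the activations produced by convolution kernels;\n",
    "2. Pooling layer (POOL): can be viewed as a special convolution-like layer that **has no parameters and is applied to each channel (depth slice) separately**;\n",
    "3. Fully connected layer (FC): an ordinary neural-network hidden layer, usually placed at the end of the network, leading into or forming the output layer;"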
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
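    "# Pooling layer (POOL)\n",
    "1. The main characteristic of a POOL layer, and its difference from a convolutional layer, is that it has no parameters and is applied to each channel (depth slice) separately;\n",
    "2. **The number of channels is unchanged: $n_c^{[l]}=n_c^{[l-1]}$**;\n",
    "3. Like a convolutional layer, it has hyperparameters:\n",
    "    1. filter size $[f_H,f_W]$;\n",
    "    2. padding $p$;\n",
    "    3. stride $s$.\n",
    "\n",
    "## Types\n",
    "1. Average pooling\n",
    "2. Max pooling\n",
    "3. etc.\n",
    "\n",
    "## Max pooling\n",
    "Max pooling works as shown below. The image is split into blocks of the same size (the pooling size). Only the largest value in each block is kept and the others are discarded, and the results keep the original spatial layout to form the output.\n",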
    "<img src=\"images/ed201fd1f1132c66648ca9cb86d5cd4b.jpg\">\n",
    "\n",
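    "Max pooling is applied separately to each depth slice and has no learnable parameters. Its main purpose is downsampling, which reduces the feature-map size, usually without noticeably hurting recognition accuracy.\n",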
    "<img src=\"images/f67c33653a031a4f499bf523b6e45283.jpg\">"
   ]
  },
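  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of max and average pooling on one channel. It uses a 2×2 window, stride 2 and no padding. The 4×4 input values are illustrative. The last lines pool each channel of a volume separately, which is why the channel count does not change. `pool2d` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pool2d(x, f, s, mode='max'):\n",
    "    # Pool a single 2-D channel; no padding and no learnable parameters\n",
    "    n_h = (x.shape[0] - f) // s + 1\n",
    "    n_w = (x.shape[1] - f) // s + 1\n",
    "    out = np.zeros((n_h, n_w))\n",
    "    for i in range(n_h):\n",
    "        for j in range(n_w):\n",
    "            window = x[i * s:i * s + f, j * s:j * s + f]\n",
    "            out[i, j] = window.max() if mode == 'max' else window.mean()\n",
    "    return out\n",
    "\n",
    "# Illustrative 4x4 input, 2x2 windows, stride 2\n",
    "x = np.array([[1, 3, 2, 1],\n",
    "              [2, 9, 1, 1],\n",
    "              [1, 3, 2, 3],\n",
    "              [5, 6, 1, 2]])\n",
    "print(pool2d(x, f=2, s=2))                   # max:     9, 2 / 6, 3\n",
    "print(pool2d(x, f=2, s=2, mode='average'))   # average: 3.75, 1.25 / 3.75, 2.0\n",
    "\n",
    "# A multi-channel input is pooled channel by channel\n",
    "vol = np.random.rand(4, 4, 3)\n",
    "pooled = np.stack([pool2d(vol[:, :, c], 2, 2) for c in range(3)], axis=-1)\n",
    "print(pooled.shape)   # (2, 2, 3)\n",
    "```"
   ]
  },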
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
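    "# Example CNN architecture\n",
    "A convolutional neural network is, roughly speaking, a combination of convolutional layers, pooling layers, ReLU layers and fully connected layers, as in the architecture shown below.\n",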
    "<img src=\"images/93fba35bdf177a53f1635288c987c37a.jpg\">\n",
    "\n",
    "**Example: shapes and parameter counts**\n",
    "<img src=\"images/b715a532e64edaa241c27eef9fdc9bfd.png\">\n",
    "**Note:**\n",
    "1. Here the parameters of a filter are assumed to be the same across its different channels;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
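    "# Why convolutions?\n",
    "## Advantages\n",
    "Compared with using only fully connected layers, convolutional layers have two main advantages: parameter sharing and sparse connectivity.\n",
    "\n",
    "1. Parameter sharing: the same filter is applied at every position of the input;\n",
    "2. Sparse connectivity: each output value depends on only a small part of the input;\n",
    "\n",
    "## Invariance\n",
    "1. Translation invariance\n",
    "2. Rotation and viewpoint invariance\n",
    "3. Scale invariance\n",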
    "\n",
    "<img src=\"images/f036d884670284321c02d1afdbec7859.jpg\">"
   ]
  },
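  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make parameter sharing and sparse connectivity concrete, here is a back-of-the-envelope sketch. It uses a hypothetical layer that maps a 32×32×3 input to a 28×28×6 output with six 5×5×3 filters. The sizes are chosen for illustration only.\n",
    "\n",
    "```python\n",
    "# Hypothetical layer: 32x32x3 input -> 28x28x6 output using six 5x5x3 filters\n",
    "n_in = 32 * 32 * 3          # 3072 input values\n",
    "n_out = 28 * 28 * 6         # 4704 output values\n",
    "\n",
    "# Fully connected: every output connected to every input (biases ignored)\n",
    "fc_params = n_in * n_out\n",
    "# Convolution: 6 filters of 5x5x3 weights plus one bias each, reused at every position\n",
    "conv_params = 6 * (5 * 5 * 3 + 1)\n",
    "\n",
    "print(fc_params, conv_params)   # 14450688 456\n",
    "\n",
    "# Sparse connectivity: each output value depends on only one 5x5x3 window\n",
    "print(5 * 5 * 3, 'of', n_in, 'inputs')   # 75 of 3072 inputs\n",
    "```"
   ]
  },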
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exercise: building blocks of a convolutional neural network (CNN)\n",
    "We now implement forward propagation and backward propagation for a convolutional layer (CONV) and a pooling layer (POOL) using NumPy.\n",
    "<img src=\"images/78b7fb4d5a0e16b587f448d6e5dc84c9.png\">\n",
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = (5.0, 4.0) \n",
    "plt.rcParams['image.interpolation'] = 'nearest'\n",
    "plt.rcParams['image.cmap'] = 'gray'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
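    "## Forward propagation\n",
    "### Convolutional layer\n",
    "A convolutional layer uses convolution kernels (filters) to transform its input into output data of a different shape.\n",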
    "<img src=\"images/7caba155fe2eec89ee0b53bb1b4d9f01.png\"  style=\"width:350px;height:200px;\">\n",
    "\n",
    "### Zero-padding\n",
    "Zero-padding adds zeros around the border of the image.\n",
    "<img src=\"images/9a91e423233cb015d469d42c56b20baa.png\" style=\"width:600px;height:400px;\">\n",
    "\n",
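    "**Purpose:**\n",
    "1. It keeps the height and width unchanged after convolution, e.g. in **same** mode;\n",
    "2. Without padding, values at the border have relatively little influence on the output; padding addresses this."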
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zero_pad(X, pad):\n",
    "    \"\"\"\n",
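    "    Zero-pad the height and width of a 4-D array X of shape (m, n_H, n_W, n_C),\n",
    "    adding `pad` zeros on each side; returns shape (m, n_H + 2*pad, n_W + 2*pad, n_C).\n",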
    "    \"\"\"\n",
    "    X_pad = np.pad(X, ((0, 0),(pad, pad),(pad, pad),(0, 0)), 'constant', constant_values=0)\n",
    "    return X_pad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "x=np.random.randn(4,3,3,2)\n",
    "x_pad=zero_pad(x,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4, 3, 3, 2), (4, 7, 7, 2))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape,x_pad.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x2256f562fd0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAC7CAYAAACNSp5xAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEEhJREFUeJzt3X+sX3V9x/Hna6XK8EKBgrZpq1cyQmRuWm0qpAnpQBbABkzGEkjEH9E0MeIQXVxxSc1ItrCFMHUYSAcViAQ0wLYOOhkGGJIpUrCFlgIWoulNy4plll5hksprf9wD+XLvt/dHz+d+z/fb83ok39xzvufT83l/7/fk1XPPj8+RbSIiol1+r+kCIiKi9xL+EREtlPCPiGihhH9ERAsl/CMiWijhHxHRQrXCX9Lxku6T9PPq53EHafc7SZur14Y6fUZERH2qc52/pH8AXrJ9laQ1wHG2/6pLu1HbQzXqjIiIguqG/zPAStu7JS0EHrR9Spd2Cf+IiD5S95j/u2zvBqh+vvMg7Y6UtEnSTyR9vGafERFR0xFTNZD0Q2BBl0V/PYN+3m17l6STgPslPWn7uS59rQZWA8ydO/fDxx9//Ay66F8LFnT79Q2mV199tekSinn22Wd/ZfvEXvd79NFHe/78+b3uNlpi79697N+/X1O1mzL8bX/0YMsk/Y+khR2HffYcZB27qp/PS3oQWApMCH/b64B1AAsWLPAll1wyVXkDYc2aNU2XUMzWrVubLqGYlStX/rKJfufPn8/atWub6Dpa4Morr5xWu7qHfTYAn6qmPwX82/gGko6T9PZq+gRgBfBUzX4jIqKGuuF/FXC2pJ8DZ1fzSFom6YaqzfuATZK2AA8AV9lO+EdENGjKwz6Tsb0XOKvL+5uAz1XT/w38UZ1+IiKirNzhGxHRQgn/iIgWSvhH1CTpHEnPSNpR3eke0fcS/hE1SJoDfBs4FzgVuFjSqc1WFTG1hH9EPcuBHbaft/0acDtwQcM1RUwp4R9RzyJgZ8f8SPXeW0haXQ1xsml0dLRnxUUcTMI/op5ut9FPGC3R9jrby2wvGxrKGIfRvIR/RD0jwJKO+cXAroZqiZi2hH9EPY8CJ0t6r6S3ARcxNuxJRF+rdYdvRNvZPiDpUuBeYA6w3va2hsuKmFLCP6Im2xuBjU3XETETOewTEdFCCf+IiBZK+EdEtFDCPyKihRL+EREtlPCPiGihIuE/1ZC2kt4u6XvV8kckDZfoNyIiDk3t8J/mkLafBf7X9h8A/wj8fd1+IyLi0JXY85/OkLYXADdX03cAZ0nqNiBWRET0QInwn86Qtm+2sX0A2AfMH7+izmFvX3nllQKlRURENyXCfzpD2s542NujjjqqQGkREdFNifCfzpC2b7aRdAQwD3ipQN8REXEISoT/dIa03QB8qpq+ELjf9oQ9/4iI6I3a4V8dw39jSNvtwPdtb5N0paTzq2Y3AvMl7QC+DEy4HDRiUElaL2mPpK1N1xIxXUWGdO42pK3ttR3T/wf8eYm+IvrQTcC1wC0N1xExbbnDN6Im2w+Rc1gxYBL+ET3QeRnz6Oho0+VEJPwjeqHzMuahoaGmy4lI+EdEtFHCPyKihRL+ETVJug34MXCKpBFJn226poipFLnUM6LNbF/cdA0RM5U9/4iIFkr4R0S0UMI/IqKFEv4RES2U8I+IaKFc7RMRk9q+fXvxda5ZMzsD+27dOjsDqz733HOzst4mZc8/IqKFEv4RES2U8I+IaKEi4S/pHEnPSNohacLBPEmflvSipM3V63Ml+o2IiENT+4SvpDnAt4GzGXtQ+6OSNth+alzT79m+tG5/ERFRX4k9/+XADtvP234NuB24oMB6IyJilpS41HMRsLNjfgT4SJd2fybpDOBZ4HLbO8c3kLQaWP3G/NVXX12gvOatWrWq6RKKufTS/PEWcTgoseevLu953Py/A8O2/xj4IXBz
txV1Pu2oQF0Rs07SEkkPSNouaZuky5quKWI6SoT/CLCkY34xsKuzge29tn9bzf4z8OEC/Ub0gwPAV2y/DzgN+IKkUxuuKWJKJcL/UeBkSe+V9DbgImBDZwNJCztmzwfK3zIY0QDbu20/Xk3vZ2zbXtRsVRFTq33M3/YBSZcC9wJzgPW2t0m6EthkewPwF5LOZ2wv6SXg03X7jeg3koaBpcAjXZa9eT5r/vz5Pa0ropsiY/vY3ghsHPfe2o7pK4ArSvQV0Y8kDQF3Al+y/fL45bbXAesAhoeHx58Ti+i53OEbUZOkuYwF/62272q6nojpSPhH1CBJwI3AdtvXNF1PxHQl/CPqWQFcApzZMXzJeU0XFTGVjOcfUYPth+l+r0tEX8uef0RECyX8IyJaKOEfEdFCCf+IiBZK+EdEtFCu9omISc3G0OqzNcz5bA05fvnll8/KepuUPf+IiBZK+EdEtFDCPyKihRL+EREtlPCPiGihhH9ERAsVCX9J6yXtkbT1IMsl6VuSdkh6QtKHSvQb0Q8kHSnpp5K2VA9x/5uma4qYSqk9/5uAcyZZfi5wcvVaDVxXqN+IfvBb4EzbHwA+CJwj6bSGa4qYVJHwt/0QY8/mPZgLgFs85ifAseMe6h4xsKrterSanVu98qjG6Gu9Oua/CNjZMT9SvRdxWJA0R9JmYA9wn+0JD3GP6Ce9Cv9uD7uYsGckabWkTZI29aCmiGJs/872B4HFwHJJ7+9c3rltj46Odl9JRA/1KvxHgCUd84uBXeMb2V5ne5ntZT2qK6Io278GHmTcObDObXtoaKiR2iI69Sr8NwCfrK76OQ3YZ3t3j/qOmFWSTpR0bDX9+8BHgaebrSpickVG9ZR0G7ASOEHSCPB1xk56Yft6YCNwHrADeAX4TIl+I/rEQuBmSXMY26H6vu27G64pYlJFwt/2xVMsN/CFEn1F9BvbTwBLm64jYiZyh29ERAsl/CMiWijhHxHRQgn/iIgWSvhHRLRQHuAeEZO65557iq9ztm50e/LJJ2dlvevXr5+V9TYpe/4RES2U8I+IaKGEf0RECyX8IyJaKOEfEdFCCf+IiBZK+EdEtFDCP6KA6jGOP5OUoZxjICT8I8q4DNjedBER05Xwj6hJ0mLgY8ANTdcSMV0J/4j6vgF8FXj9YA3yAPfoN0XCX9J6SXskbT3I8pWS9knaXL3Wlug3ommSVgF7bD82Wbs8wD36TamB3W4CrgVumaTNj2yvKtRfRL9YAZwv6TzgSOAYSd+1/YmG64qYVJE9f9sPAS+VWFfEILF9he3FtoeBi4D7E/wxCHo5pPPpkrYAu4C/tL1tfANJq4HVACeeeCI33XRTD8ubPYfTn/mzNWRuEyQ1XUJEY3p1wvdx4D22PwD8E/Cv3Rp1HhedN29ej0qLKMP2gzm0GYOiJ+Fv+2Xbo9X0RmCupBN60XdEREzUk/CXtEDV39iSllf97u1F3xERMVGRY/6SbgNWAidIGgG+DswFsH09cCHweUkHgFeBi2y7RN8RETFzRcLf9sVTLL+WsUtBIyKiD+QO34iIFurlpZ4RMYBeeOGFgVgnwJYtW2ZlvYej7PlHRLRQwj8iooUS/hERLZTwj4hooYR/REQLJfwjIloo4R8R0UK5zj+iAEm/APYDvwMO2F7WbEURk0v4R5TzJ7Z/1XQREdORwz4RES2U8I8ow8B/SnqseiLdW0haLWmTpE2jo6MNlBfxVjnsE1HGCtu7JL0TuE/S09WzrYGxp9QB6wCGh4cznHk0Lnv+EQXY3lX93AP8C7C82YoiJpfwj6hJ0jskHf3GNPCnwNZmq4qYXO3wl7RE0gOStkvaJumyLm0k6VuSdkh6QtKH6vYb0UfeBTwsaQvwU+Ae2z9ouKaISZU45n8A+Irtx6u9n8ck3Wf7qY425wInV6+PANdVPyMGnu3ngQ80XUfETNTe87e92/bj1fR+YDuwaFyzC4BbPOYnwLGSFtbtOyIiDk3RY/6ShoGl
wCPjFi0CdnbMjzDxP4i3XA63b9++kqVFRESHYuEvaQi4E/iS7ZfHL+7yTyZc7mZ7ne1ltpfNmzevVGkRETFOkfCXNJex4L/V9l1dmowASzrmFwO7SvQdEREzV+JqHwE3AtttX3OQZhuAT1ZX/ZwG7LO9u27fERFxaEpc7bMCuAR4UtLm6r2vAe8GsH09sBE4D9gBvAJ8pkC/ERFxiGqHv+2H6X5Mv7ONgS/U7SsiIsrIHb4RES2U8I+IaKGEf0RECyX8IyJaKOEfEdFCCf+IiBZK+EfUJOlYSXdIeroa2vz0pmuKmEoe4xhR3zeBH9i+UNLbgKOaLihiKgn/iBokHQOcAXwawPZrwGtN1hQxHTnsE1HPScCLwHck/UzSDdWjHN+ic7jy0dHR3lcZMU7CP6KeI4APAdfZXgr8BlgzvlHncOVDQ0O9rjFigoR/RD0jwIjtNx5gdAdj/xlE9LWEf0QNtl8Adko6pXrrLOCpSf5JRF/ICd+I+r4I3Fpd6fM8GbI8BkDCP6Im25uBZU3XETETOewTEdFCJR7juETSA9WdjdskXdalzUpJ+yRtrl5r6/YbERGHrsRhnwPAV2w/Lulo4DFJ99kef9LrR7ZXFegvIiJqqr3nb3u37cer6f3AdmBR3fVGRMTsKXrMX9IwsBR4pMvi0yVtkfQfkv6wZL8RETEzGnu2eoEVSUPAfwF/a/uuccuOAV63PSrpPOCbtk/uso7VwOpq9hTgmSLFTe4E4Fc96KcXDpfP0qvP8R7bJ/agn7eQ9CLwy2k2H6TvdJBqhcGqdya1Tmu7LhL+kuYCdwP32r5mGu1/ASyz3fgvXtIm24fFZXqHy2c5XD5HCYP0uxikWmGw6p2NWktc7SPgRmD7wYJf0oKqHZKWV/3urdt3REQcmhJX+6wALgGelLS5eu9rwLsBbF8PXAh8XtIB4FXgIpc63hQRETNWO/xtPwxoijbXAtfW7WuWrGu6gIIOl89yuHyOEgbpdzFItcJg1Vu81mInfCMiYnBkeIeIiBZqbfhLOkfSM5J2SJrw8I1BIWm9pD2StjZdS13TGSqkLQZp+xzE703SnOrJa3c3XctUJB0r6Q5JT1e/49OLrLeNh30kzQGeBc5m7GEcjwIXdxmSou9JOgMYBW6x/f6m66lD0kJgYedQIcDHB/F7qWPQts9B/N4kfZmxkViP6fdhZyTdzNjwODdUw4YfZfvXddfb1j3/5cAO289XD9y+Hbig4ZoOie2HgJearqOEDBXypoHaPgfte5O0GPgYcEPTtUylukH2DMYup8f2ayWCH9ob/ouAnR3zI/TxxtpGUwwVcrgb2O1zQL63bwBfBV5vupBpOAl4EfhOdZjqBknvKLHitoZ/t0tT23f8q09VQ4XcCXzJ9stN19OAgdw+B+F7k7QK2GP7saZrmaYjGHsm9HW2lwK/AYqcA2pr+I8ASzrmFwO7GqolOlRDhdwJ3Dp+jKgWGbjtc4C+txXA+dUQM7cDZ0r6brMlTWoEGLH9xl9SdzD2n0FtbQ3/R4GTJb23OoFyEbCh4ZpabzpDhbTEQG2fg/S92b7C9mLbw4z9Xu+3/YmGyzoo2y8AOyWdUr11FlDkRHorw9/2AeBS4F7GTk593/a2Zqs6NJJuA34MnCJpRNJnm66phjeGCjmz46lv5zVdVK8N4PaZ7212fRG4VdITwAeBvyux0lZe6hkR0Xat3POPiGi7hH9ERAsl/CMiWijhHxHRQgn/iIgWSvhHRLRQwj8iooUS/hERLfT/Oz1HxB9j7mMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig,axs=plt.subplots(1,2)\n",
    "axs[0].imshow(x[0,:,:,0])\n",
    "axs[1].imshow(x_pad[0,:,:,0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
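    "### Single convolution step\n",
    "Here we implement a function that applies one filter to one slice of the input and returns the resulting convolution value.\n",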
    "<img src=\"images/9b5d8b7ba3f91d34f7e2b6159ad4abf1.gif\" style=\"width:500px;height:300px;\">\n",
    "\n",
    "**Note:**\n",
    "1. Multiply element-wise, then sum."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_single_step(a_slice_prev, W, b):\n",
    "    # Element-wise product of the input slice and the filter, summed, plus the bias\n",
    "    s = np.sum(np.multiply(a_slice_prev, W)) + b\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.362141595840531"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conv_single_step(np.random.randn(4,4,3),np.random.randn(4,4,3),np.random.randn())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
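    "### Convolution (forward pass)\n",
    "In the forward pass for a single example, each filter of the convolutional layer produces a 2-D output. Stacking the outputs of all filters gives a 3-D output, the same rank as the input. The output dimensions are:\n",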
    "\n",
    "$$ n_H = \\lfloor \\frac{n_{H_{prev}} - f + 2 \\times pad}{stride} \\rfloor +1 $$\n",
    "$$ n_W = \\lfloor \\frac{n_{W_{prev}} - f + 2 \\times pad}{stride} \\rfloor +1 $$\n",
    "$$ n_C = \\text{number of filters used in the convolution}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_forward(A_prev, W, b, hparameters):\n",
    "    # Dimensions of the input\n",
    "    (m, n_H_prev, n_W_prev, n_C_prev1) = A_prev.shape\n",
    "    # Dimensions of the filters\n",
    "    (f, f, n_C_prev, n_C) = W.shape\n",
    "    \n",
    "    assert n_C_prev1 == n_C_prev  # filter channels must match input channels\n",
    "\n",
    "    stride = hparameters['stride']\n",
    "    pad = hparameters['pad']\n",
    "\n",
    "    n_H = (n_H_prev + 2 * pad - f) // stride + 1  # output height\n",
    "    n_W = (n_W_prev + 2 * pad - f) // stride + 1  # output width\n",
    "\n",
    "    Z = np.zeros((m, n_H, n_W, n_C))\n",
    "\n",
    "    A_prev_pad = zero_pad(A_prev, pad)\n",
    "    \n",
    "    for i in range(m):                                 # loop over examples\n",
    "        a_prev_pad = A_prev_pad[i]                     # padded input of example i\n",
    "        for h in range(n_H):                           # loop over output height\n",
    "            for w in range(n_W):                       # loop over output width\n",
    "                for c in range(n_C):                   # loop over filters (output channels)\n",
    "                    \n",
    "                    # Locate the current window\n",
    "                    vert_start = h * stride\n",
    "                    vert_end = vert_start + f\n",
    "                    horiz_start = w * stride\n",
    "                    horiz_end = horiz_start + f\n",
    "                    \n",
    "                    # Slice of the input covered by the filter\n",
    "                    a_slice_prev = a_prev_pad[vert_start:vert_end, horiz_start:horiz_end, :]\n",
    "                  \n",
    "                    # Single convolution step; the bias is added once, outside the sum\n",
    "                    Z[i, h, w, c] = np.sum(np.multiply(a_slice_prev, W[:, :, :, c])) + b[0, 0, 0, c]\n",
    "    \n",
    "    # Check the output shape\n",
    "    assert(Z.shape == (m, n_H, n_W, n_C))\n",
    "    \n",
    "    # Cache values for backpropagation\n",
    "    cache = (A_prev, W, b, hparameters)\n",
    "    \n",
    "    return Z, cache"
   ]
  },
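  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quadruple loop is easy to follow but slow. As an optional cross-check, the sketch below compares a loop version against a vectorized version. The loop version adds the bias once per output value. The vectorized version is built on `numpy.lib.stride_tricks.sliding_window_view` (NumPy 1.20+) and `np.einsum`. Both function names are hypothetical and not part of the exercise, and the layouts follow the notebook: `W` is `(f, f, n_C_prev, n_C)` and `b` is `(1, 1, 1, n_C)`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from numpy.lib.stride_tricks import sliding_window_view\n",
    "\n",
    "def conv_forward_loops(A_prev, W, b, stride, pad):\n",
    "    # Reference version: explicit loops, bias added once per output value\n",
    "    A = np.pad(A_prev, ((0, 0), (pad, pad), (pad, pad), (0, 0)))\n",
    "    f, _, _, n_C = W.shape\n",
    "    m = A.shape[0]\n",
    "    n_H = (A.shape[1] - f) // stride + 1\n",
    "    n_W = (A.shape[2] - f) // stride + 1\n",
    "    Z = np.zeros((m, n_H, n_W, n_C))\n",
    "    for i in range(m):\n",
    "        for h in range(n_H):\n",
    "            for w in range(n_W):\n",
    "                win = A[i, h * stride:h * stride + f, w * stride:w * stride + f, :]\n",
    "                for c in range(n_C):\n",
    "                    Z[i, h, w, c] = np.sum(win * W[:, :, :, c]) + b[0, 0, 0, c]\n",
    "    return Z\n",
    "\n",
    "def conv_forward_vectorized(A_prev, W, b, stride, pad):\n",
    "    A = np.pad(A_prev, ((0, 0), (pad, pad), (pad, pad), (0, 0)))\n",
    "    f = W.shape[0]\n",
    "    # windows[m, h, w, c, i, j] == A[m, h*stride + i, w*stride + j, c]\n",
    "    windows = sliding_window_view(A, (f, f), axis=(1, 2))[:, ::stride, ::stride]\n",
    "    # Sum over each f x f x n_C_prev window for every filter k, then add the bias\n",
    "    return np.einsum('mhwcij,ijck->mhwk', windows, W) + b.reshape(1, 1, 1, -1)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "A_prev = rng.standard_normal((2, 5, 5, 3))\n",
    "W = rng.standard_normal((3, 3, 3, 4))\n",
    "b = rng.standard_normal((1, 1, 1, 4))\n",
    "Z1 = conv_forward_loops(A_prev, W, b, stride=2, pad=1)\n",
    "Z2 = conv_forward_vectorized(A_prev, W, b, stride=2, pad=1)\n",
    "print(Z1.shape, np.allclose(Z1, Z2))   # (2, 3, 3, 4) True\n",
    "```"
   ]
  },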
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((10, 4, 4, 3), (2, 2, 3, 8), (1, 1, 1, 8))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A_prev=np.random.randn(10,4,4,3)\n",
    "W=np.random.randn(2,2,A_prev.shape[3],8)\n",
    "b=np.random.randn(1,1,1,W.shape[3])\n",
    "A_prev.shape,W.shape,b.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 7, 7, 8)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hparameters={'pad':2,\"stride\":1}\n",
    "Z,cache_conv=conv_forward(A_prev,W,b,hparameters)\n",
    "Z.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pooling (forward pass)\n",
    "$$ n_H = \\lfloor \\frac{n_{H_{prev}} - f}{stride} \\rfloor +1 $$\n",
    "$$ n_W = \\lfloor \\frac{n_{W_{prev}} - f}{stride} \\rfloor +1 $$\n",
    "$$ n_C = n_{C_{prev}}$$\n",
    "<table>\n",
    "<tr>\n",
    "<td>\n",
    "<img src=\"images/ef39f1ae883591583a782c61429f1f22.png\" style=\"width:500px;height:300px;\">\n",
    "</td>\n",
    "<td>\n",
    "<img src=\"images/f937baeb02afa8054d19effba6f1a3a8.png\" style=\"width:500px;height:300px;\">\n",
    "</td>\n",
    "</tr>\n",
    "</table>\n",
    "\n",
    "**Differences from convolution:**\n",
    "1. Pooling is applied to each channel separately, so the number of input channels equals the number of output channels;\n",
    "2. There is usually no zero-padding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pool_forward(A_prev, hparameters, mode = \"max\"):\n",
    "    # Dimensions of the input\n",
    "    (m, n_H_prev, n_W_prev, n_C_prev) = A_prev.shape\n",
    "    \n",
    "    # Hyperparameters\n",
    "    f = hparameters[\"f\"]\n",
    "    stride = hparameters[\"stride\"]\n",
    "    \n",
    "    # Output dimensions (no padding; the number of channels is unchanged)\n",
    "    n_H = (n_H_prev - f) // stride + 1\n",
    "    n_W = (n_W_prev - f) // stride + 1\n",
    "    n_C = n_C_prev\n",
    "    \n",
    "    # Allocate the output\n",
    "    A = np.zeros((m, n_H, n_W, n_C))\n",
    "\n",
    "    for i in range(m):                           # loop over examples\n",
    "        for h in range(n_H):                     # loop over output height\n",
    "            for w in range(n_W):                 # loop over output width\n",
    "                for c in range(n_C):             # loop over channels\n",
    "                    \n",
    "                    # Locate the current window\n",
    "                    vert_start = h * stride\n",
    "                    vert_end = vert_start + f\n",
    "                    horiz_start = w * stride\n",
    "                    horiz_end = horiz_start + f\n",
    "                    \n",
    "                    # Window of channel c to be pooled\n",
    "                    a_prev_slice = A_prev[i, vert_start:vert_end, horiz_start:horiz_end, c]\n",
    "                    \n",
    "                    # Apply the chosen pooling operation\n",
    "                    if mode == \"max\":\n",
    "                        A[i, h, w, c] = np.max(a_prev_slice)\n",
    "                    elif mode == \"average\":\n",
    "                        A[i, h, w, c] = np.mean(a_prev_slice)\n",
    "        \n",
    "    # Cache values for backpropagation\n",
    "    cache = (A_prev, hparameters)\n",
    "    \n",
    "    # Check the output shape\n",
    "    assert(A.shape == (m, n_H, n_W, n_C))\n",
    "    \n",
    "    return A, cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_prev=np.random.randn(2,4,4,3)\n",
    "hparameters={\"stride\":1,\"f\":4}\n",
    "A,cache=pool_forward(A_prev,hparameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 1, 1, 3)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[[2.1656048 , 1.52669475, 1.05151787]]],\n",
       "\n",
       "\n",
       "       [[[1.04562939, 1.69453761, 1.43035074]]]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[[-0.13615407,  0.14597596, -0.38591061]]],\n",
       "\n",
       "\n",
       "       [[[-0.22439704, -0.02152892,  0.12799356]]]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A,cache=pool_forward(A_prev,hparameters,mode=\"average\")\n",
    "A"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
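    "## Backward propagation\n",
    "### Convolutional layer\n",
    "$h$: index over the output height; $w$: index over the output width; $c$: index over the filters.\n",
    "#### dA\n",
    "$$ dA \\mathrel{+}= \\sum _{h=0} ^{n_H-1} \\sum_{w=0} ^{n_W-1} W_c \\times dZ_{hw}$$\n",
    "Each term is added to the window of $dA$ that produced $Z_{hw}$.\n",
    "#### dW\n",
    "$$ dW_c \\mathrel{+}= \\sum _{h=0} ^{n_H-1} \\sum_{w=0} ^{n_W-1} a_{slice} \\times dZ_{hw}$$\n",
    "Here $a_{slice}$ is the input window that produced $Z_{hw}$.\n",
    "#### db\n",
    "$$ db_c = \\sum_h \\sum_w dZ_{hw}$$\n",
    "$dW$ and $db$ are also accumulated over all examples in the batch."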
   ]
  },
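  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to gain confidence in the three update rules above is a numerical gradient check. The sketch below does this on a single example with stride 2 and no padding. It defines its own small `conv` and `conv_grads` helpers (hypothetical names) rather than calling the exercise functions. It uses the made-up scalar loss $L=\\sum Z \\odot G$, so that $dZ=G$. The analytic gradients are then compared with central finite differences.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def conv(a, W, b, s):\n",
    "    # Forward pass for one example: a (n, n, c_prev), W (f, f, c_prev, c), b (c,)\n",
    "    f, _, _, n_c = W.shape\n",
    "    n_out = (a.shape[0] - f) // s + 1\n",
    "    Z = np.zeros((n_out, n_out, n_c))\n",
    "    for h in range(n_out):\n",
    "        for w in range(n_out):\n",
    "            win = a[h * s:h * s + f, w * s:w * s + f, :]\n",
    "            for c in range(n_c):\n",
    "                Z[h, w, c] = np.sum(win * W[:, :, :, c]) + b[c]\n",
    "    return Z\n",
    "\n",
    "def conv_grads(a, W, s, dZ):\n",
    "    # The update rules above: dA += W_c dZ_hw, dW_c += a_slice dZ_hw, db_c += dZ_hw\n",
    "    f, n_c = W.shape[0], W.shape[3]\n",
    "    da, dW, db = np.zeros_like(a), np.zeros_like(W), np.zeros(n_c)\n",
    "    for h in range(dZ.shape[0]):\n",
    "        for w in range(dZ.shape[1]):\n",
    "            win = (slice(h * s, h * s + f), slice(w * s, w * s + f))\n",
    "            for c in range(n_c):\n",
    "                da[win] += W[:, :, :, c] * dZ[h, w, c]\n",
    "                dW[:, :, :, c] += a[win] * dZ[h, w, c]\n",
    "                db[c] += dZ[h, w, c]\n",
    "    return da, dW, db\n",
    "\n",
    "def num_grad(x, fn, eps=1e-6):\n",
    "    # Central differences, perturbing one entry of x (in place) at a time\n",
    "    g = np.zeros_like(x)\n",
    "    for idx in np.ndindex(x.shape):\n",
    "        old = x[idx]\n",
    "        x[idx] = old + eps\n",
    "        up = fn()\n",
    "        x[idx] = old - eps\n",
    "        down = fn()\n",
    "        x[idx] = old\n",
    "        g[idx] = (up - down) / (2 * eps)\n",
    "    return g\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "a = rng.standard_normal((5, 5, 2))\n",
    "W = rng.standard_normal((3, 3, 2, 2))\n",
    "b = rng.standard_normal(2)\n",
    "G = rng.standard_normal((2, 2, 2))   # stride 2 gives a 2x2x2 output\n",
    "\n",
    "def loss():\n",
    "    # Hypothetical scalar loss L = sum(Z * G), so dL/dZ = G\n",
    "    return np.sum(conv(a, W, b, 2) * G)\n",
    "\n",
    "da, dW, db = conv_grads(a, W, 2, G)\n",
    "print(np.allclose(da, num_grad(a, loss), atol=1e-5))   # True\n",
    "print(np.allclose(dW, num_grad(W, loss), atol=1e-5))   # True\n",
    "print(np.allclose(db, num_grad(b, loss), atol=1e-5))   # True\n",
    "```"
   ]
  },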
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_backward(dZ, cache):\n",
    "    # Retrieve values from the cache\n",
    "    (A_prev, W, b, hparameters) = cache\n",
    "    \n",
    "    # Dimensions of the input\n",
    "    (m, n_H_prev, n_W_prev, n_C_prev) = A_prev.shape\n",
    "    \n",
    "    # Dimensions of the filters\n",
    "    (f, f, n_C_prev, n_C) = W.shape\n",
    "    \n",
    "    # Hyperparameters\n",
    "    stride = hparameters['stride']\n",
    "    pad = hparameters['pad']\n",
    "    \n",
    "    # Dimensions of the upstream gradient dZ\n",
    "    (m, n_H, n_W, n_C) = dZ.shape\n",
    "    \n",
    "    # Initialize the gradients\n",
    "    dA_prev = np.zeros((m, n_H_prev, n_W_prev, n_C_prev))\n",
    "    dW = np.zeros((f, f, n_C_prev, n_C))\n",
    "    db = np.zeros((1, 1, 1, n_C))\n",
    "\n",
    "    # Zero-pad the input and its gradient\n",
    "    A_prev_pad = zero_pad(A_prev, pad)\n",
    "    dA_prev_pad = zero_pad(dA_prev, pad)\n",
    "    \n",
    "    for i in range(m):                       # loop over examples\n",
    "        a_prev_pad = A_prev_pad[i]\n",
    "        da_prev_pad = dA_prev_pad[i]         # a view: updates write into dA_prev_pad\n",
    "        \n",
    "        for h in range(n_H):                   # loop over output height\n",
    "            for w in range(n_W):               # loop over output width\n",
    "                for c in range(n_C):           # loop over filters\n",
    "                    \n",
    "                    # Locate the current window\n",
    "                    vert_start = h * stride\n",
    "                    vert_end = vert_start + f\n",
    "                    horiz_start = w * stride\n",
    "                    horiz_end = horiz_start + f\n",
    "                    \n",
    "                    # Input window that produced Z[i, h, w, c]\n",
    "                    a_slice = a_prev_pad[vert_start:vert_end, horiz_start:horiz_end, :]\n",
    "\n",
    "                    # Accumulate the gradients\n",
    "                    da_prev_pad[vert_start:vert_end, horiz_start:horiz_end, :] += W[:,:,:,c] * dZ[i, h, w, c]\n",
    "                    dW[:,:,:,c] += a_slice * dZ[i, h, w, c]\n",
    "                    db[:,:,:,c] += dZ[i, h, w, c]\n",
    "                    \n",
    "        # Strip the padding; explicit end indices also work when pad == 0\n",
    "        dA_prev[i, :, :, :] = da_prev_pad[pad:pad + n_H_prev, pad:pad + n_W_prev, :]\n",
    "    # Check the output shape\n",
    "    assert(dA_prev.shape == (m, n_H_prev, n_W_prev, n_C_prev))\n",
    "    \n",
    "    return dA_prev, dW, db"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use Z as a stand-in for the upstream gradient dZ, just to check the shapes\n",
    "dA, dW, db = conv_backward(Z, cache_conv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((10, 4, 4, 3), (2, 2, 3, 8), (1, 1, 1, 8))"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dA.shape,dW.shape,db.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
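    "### Pooling layer\n",
    "#### Max pooling\n",
    "In the backward pass, only the position that held the maximum of each window receives the gradient, so we build a mask that marks that position:\n",
    "$$ X = \\begin{bmatrix}\n",
    "1 & 3 \\\\\n",
    "4 & 2\n",
    "\\end{bmatrix} \\quad \\rightarrow  \\quad M =\\begin{bmatrix}\n",
    "0 & 0 \\\\\n",
    "1 & 0\n",
    "\\end{bmatrix}$$"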
   ]
  },
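  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a small sketch of how the mask routes a gradient, using the $2\\times2$ window from the example above and a hypothetical upstream gradient. It also flags a subtlety. When a window contains ties, `x == np.max(x)` marks every maximal entry, so the gradient is passed to each of them. The first-maximum variant at the end is one possible alternative, not what this notebook does.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def create_mask_from_window(x):\n",
    "    # True where x equals the window maximum\n",
    "    return x == np.max(x)\n",
    "\n",
    "window = np.array([[1, 3],\n",
    "                   [4, 2]])\n",
    "dz = 5.0                                    # hypothetical upstream gradient\n",
    "grad = create_mask_from_window(window) * dz\n",
    "print(grad)   # only the position of the 4 receives 5.0; the rest are 0\n",
    "\n",
    "# With ties, every maximal entry is marked\n",
    "tied = np.array([[4, 3],\n",
    "                 [4, 2]])\n",
    "mask = create_mask_from_window(tied)\n",
    "print(mask.sum())   # 2\n",
    "\n",
    "# Alternative (an assumption, not the notebook's choice): keep only the first maximum\n",
    "first = np.zeros_like(mask)\n",
    "first[np.unravel_index(np.argmax(tied), tied.shape)] = True\n",
    "print(first.sum())  # 1\n",
    "```"
   ]
  },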
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_mask_from_window(x):\n",
    "    return x==np.max(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[-1.18580165, -0.37167479, -0.89046327],\n",
       "        [ 0.7089688 , -0.67326325, -1.2156681 ]]),\n",
       " array([[False, False, False],\n",
       "        [ True, False, False]]))"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x=np.random.randn(2,3)\n",
    "x,create_mask_from_window(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
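    "#### Average pooling\n",
    "In the backward pass, the gradient of each output is spread evenly over its window. For a $2\\times2$ window:\n",
    "$$ dZ = 1 \\quad \\rightarrow  \\quad dA_{prev} =\\begin{bmatrix}\n",
    "1/4 & 1/4 \\\\\n",
    "1/4 & 1/4\n",
    "\\end{bmatrix}$$"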
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def distribute_value(dz, shape):\n",
    "    (n_H, n_W) = shape\n",
    "    average = dz / (n_H * n_W)\n",
    "    a = np.ones(shape) * average\n",
    "    return a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.5, 0.5],\n",
       "       [0.5, 0.5]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "distribute_value(2,(2,2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Pooling backward pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pool_backward(dA, cache, mode = \"max\"):    \n",
    "    (A_prev, hparameters) = cache\n",
    "    \n",
    "    stride = hparameters['stride']\n",
    "    f = hparameters['f']\n",
    "\n",
    "    m, n_H_prev, n_W_prev, n_C_prev = A_prev.shape\n",
    "    m, n_H, n_W, n_C = dA.shape\n",
    "    \n",
    "    dA_prev = np.zeros_like(A_prev)\n",
    "    \n",
    "    for i in range(m):                     \n",
    "        a_prev = A_prev[i]\n",
    "        \n",
    "        for h in range(n_H):                  \n",
    "            for w in range(n_W):         \n",
    "                for c in range(n_C):  \n",
    "\n",
    "                    vert_start = h * stride\n",
    "                    vert_end = vert_start + f\n",
    "                    horiz_start = w * stride\n",
    "                    horiz_end = horiz_start + f\n",
    "\n",
    "                    if mode == \"max\":\n",
    "                        # Route the gradient only to the position that held the max in the forward pass\n",
    "                        a_prev_slice = a_prev[vert_start:vert_end, horiz_start:horiz_end, c]\n",
    "                        mask = create_mask_from_window(a_prev_slice)\n",
    "                        # dA is indexed by output position (h, w), not by the window's top-left corner\n",
    "                        dA_prev[i, vert_start:vert_end, horiz_start:horiz_end, c] += mask * dA[i, h, w, c]\n",
    "\n",
    "                    elif mode == \"average\":\n",
    "                        # Spread the gradient of output (h, w) evenly over its f x f window\n",
    "                        da = dA[i, h, w, c]\n",
    "                        shape = (f, f)\n",
    "                        dA_prev[i, vert_start:vert_end, horiz_start:horiz_end, c] += distribute_value(da, shape)\n",
    "    assert(dA_prev.shape == A_prev.shape)\n",
    "    \n",
    "    return dA_prev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_prev=np.random.randn(5,5,3,2)\n",
    "hparameters={\"stride\":1,\"f\":2}\n",
    "A,cache=pool_forward(A_prev,hparameters)\n",
    "\n",
    "dA=np.random.randn(5,4,2,2)\n",
    "\n",
    "dA_prev=pool_backward(dA,cache,mode=\"max\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 5, 3, 2)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dA_prev.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "dA_prev=pool_backward(dA,cache,mode=\"average\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 5, 3, 2)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dA_prev.shape"
   ]
  },
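  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A finite-difference gradient check is a common way to validate a hand-written pooling backward pass. The cell below is a minimal, self-contained sketch on a single 2-D channel; `maxpool2d`, `maxpool2d_backward` and `numeric_grad` are hypothetical helpers written for this check, not the notebook's `pool_forward`/`pool_backward`. It checks stride 1, where windows overlap and gradients must accumulate with `+=`, and stride 2, where the upstream gradient has to be read at output position `(h, w)` rather than at the window's top-left corner. Both maximum differences should be at floating-point noise level as long as no window contains a tie."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def maxpool2d(x, f, stride):\n",
    "    # Hypothetical helper: max pooling of one 2-D channel, no padding\n",
    "    n_H = (x.shape[0] - f) // stride + 1\n",
    "    n_W = (x.shape[1] - f) // stride + 1\n",
    "    out = np.zeros((n_H, n_W))\n",
    "    for h in range(n_H):\n",
    "        for w in range(n_W):\n",
    "            out[h, w] = x[h * stride:h * stride + f, w * stride:w * stride + f].max()\n",
    "    return out\n",
    "\n",
    "def maxpool2d_backward(dout, x, f, stride):\n",
    "    # Route each upstream gradient dout[h, w] to the max position of window (h, w)\n",
    "    dx = np.zeros_like(x)\n",
    "    for h in range(dout.shape[0]):\n",
    "        for w in range(dout.shape[1]):\n",
    "            window = x[h * stride:h * stride + f, w * stride:w * stride + f]\n",
    "            mask = window == np.max(window)\n",
    "            dx[h * stride:h * stride + f, w * stride:w * stride + f] += mask * dout[h, w]\n",
    "    return dx\n",
    "\n",
    "def numeric_grad(x, dout, f, stride, eps=1e-6):\n",
    "    # Central differences of the scalar L(x) = sum(maxpool2d(x) * dout)\n",
    "    grad = np.zeros_like(x)\n",
    "    for idx in np.ndindex(*x.shape):\n",
    "        x_plus, x_minus = x.copy(), x.copy()\n",
    "        x_plus[idx] += eps\n",
    "        x_minus[idx] -= eps\n",
    "        grad[idx] = (np.sum(maxpool2d(x_plus, f, stride) * dout)\n",
    "                     - np.sum(maxpool2d(x_minus, f, stride) * dout)) / (2 * eps)\n",
    "    return grad\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "errors = {}\n",
    "for stride, n in [(1, 5), (2, 6)]:\n",
    "    x = rng.standard_normal((n, n))\n",
    "    n_out = (n - 2) // stride + 1  # output size for f = 2\n",
    "    dout = rng.standard_normal((n_out, n_out))\n",
    "    analytic = maxpool2d_backward(dout, x, f=2, stride=stride)\n",
    "    numeric = numeric_grad(x, dout, f=2, stride=stride)\n",
    "    errors[stride] = np.max(np.abs(analytic - numeric))\n",
    "print(errors)"
   ]
  },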
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow-CNN\n",
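    "The exercises in the previous section were meant to build a deeper understanding of how convolutional neural networks work and how they are implemented. All of the CNN building blocks covered there are implemented efficiently in modern, mature frameworks, so below we build a CNN model with TensorFlow.\n",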
    "<img src=\"images/6c8d61508321ac444175370124200350.png\" style=\"width:800px;height:300px;\">\n",
    "\n",
    "## Environment setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy\n",
    "from PIL import Image\n",
    "from scipy import ndimage\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset():\n",
    "    train_dataset = h5py.File('datasets/train_signs.h5', \"r\")\n",
    "    train_set_x_orig = np.array(train_dataset[\"train_set_x\"][:]) # your train set features\n",
    "    train_set_y_orig = np.array(train_dataset[\"train_set_y\"][:]) # your train set labels\n",
    "\n",
    "    test_dataset = h5py.File('datasets/test_signs.h5', \"r\")\n",
    "    test_set_x_orig = np.array(test_dataset[\"test_set_x\"][:]) # your test set features\n",
    "    test_set_y_orig = np.array(test_dataset[\"test_set_y\"][:]) # your test set labels\n",
    "\n",
    "    classes = np.array(test_dataset[\"list_classes\"][:]) # the list of classes\n",
    "    \n",
    "    train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))\n",
    "    test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))\n",
    "    \n",
    "    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_orig, Y_train_orig, X_test_orig, Y_test_orig, classes = load_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y=1\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJztfWmMZMeR3hd19H1M99ycGXJG4ogUKfHaEUWJ8oqSSJmWdlcLQzL2gEEbBPhnbWjhNVaSDRi7hg1If3bXPwwBhLVe/pB17CFTIHRRXNEraiWKQ1EUj1neM5z76mN6erqr60j/qOqXEfEqs19Vd1dRevEBjc5XmS8z33uV9SIyIr4g5xwMBkO+UOj3BAwGQ+9hC99gyCFs4RsMOYQtfIMhh7CFbzDkELbwDYYcwha+wZBDrGvhE9G9RPQSEb1KRJ/dqEkZDIbNBXXrwENERQAvA7gHwAkATwH4Xefcixs3PYPBsBkorePc2wG86px7HQCI6KsAPgEguPCnp6fcvj1Xta8k8mX9W0TYAHTRSeSUeG8bMuG3Hrq+rI2+HxvgbbrpDqsuchRtui4cP3ESMzMza97w9Sz8PQCOs+MTAN4bO2Hfnqvw7Ye/DkCu89QH6kY40ZjaFpuH/gNKD9B2rHS7wJxUH/HTsn7R9QX44+6/C9Sm1GaoLvpwHfXB23X3C+qyHUQRakmRLlzqC8iKEQlZ1jVUXWRWGaVu3oxSPyzN43/+8X+Zqa/16PjtHlnqCojoASI6TESHL87MrmM4g8GwUVjPG/8EgH3seC+AU7qRc+5BAA8CwM3vvtGt/jY4/bsR/QX2IHUUaukiryfxFkuNFpY8iNr3r1/wsV93CkkvumnX0jGbo7iW1MWwocKDNcTrLjKsugnyHocvjL959SMLdygbxoVqfp0u2Cp+712wSpwmbre+H5Gxg/ffBVul109nMuJ63vhPAThIRAeIaADA7wD45jr6MxgMPULXb3znXI2I/h2A7wIoAvhL59wLGzYzg8GwaViPqA/n3LcAfGuD5mIwGHqEdS38zYJWV1xAP0ppRlzJSm3bttfd09oVr1S6pKiK7CdwnbAD/T9k2YjpvvGtgGwbJ6ld7IA5IH2VGfX/YN/xPvjeQ3YNNqaFx06LXUvEkuRiunugi9QeRej7rdoJJb8Dc2EbmMuuwZBD2MI3GHKI/on6WlQRVq4ubVlRaa29MJRWCMKibdDUF3NCiUjzcTm9OzNa0GEl3PsaHorZTGBaXQipJ/oxUECtSI8XnkfWb8tm0Et26VckEdYk1GCdOJvFYW98gyGHsIVvMOQQtvANhhyi5zp+ogtGAmBSulhAfYmpVFrn5CYUJ0wrEd00YuVylHEvIDaviBkwq+Ka9V7F5hEPWMmsQQfnFQsWkq7PkT5iftaZ1dsN8YkOQ/rs6ko2i4x10b0dddzh/oK98Q2GHMIWvsGQQ/RU1HdgElDU/BCJh46IU464+FqI1LGKmGtd1FSWTWRPxeZLPUOdF7onYU8vHZcdMvmkp8jMYxlNQVlj+GNt0xJw+D5mNZXxOPho7H9E3M7+bQxPJC3Cs7qsDoQR257sI/K9ygB74xsMOYQtfIMhh+ib516KwihCDCFFqADRBLTUqEW5gAisjrP5wUGJ7JFmKZGSi5vd7ZjLjd9U1EjbHtLXHyYckU1jom2YjESoEpHrzBobE59RN/cxu0UlOkfRYyyoKGKxCfaffeu+U2dXe+MbDDmELXyDIYewhW8w5BB90/G1KhoiI9B1mVkpIu2kOUzVRWi+Q3pUzBMrSl6RQoCII9UqtqnQfu8htaUSOQoNHn0usT2ECGKW1VC7uLlXz6K9R2iKnjrmGRgZWla5tuXWgJHzYl6g7TuJfeeywN74BkMOYQvfYMgheh+k41oZRkh51kWkxuyikDwr1H80003UIS9gktEc6jFvtMjYoZwBcW+0DUBa72o3jcyq
TxQpCTUcnCWHjqhn/JyIChmbRkyEz5p1JzPRRyw4S46csUMYEYfBYFgbtvANhhzCFr7BkEP0QccPfE4ZGqEDN9eUyy7vI3QQJ3xwoV5SUwq7blJkEyEYfBW3xQXHFlNM7ZuEDtQH0YgwXqMNbgGX5tTcw7p7o32zrrP2RvcJsnnURhGbRyNrRtyMY6VP7OzMNd/4RPSXRHSOiJ5nn00T0aNE9Err/1QXUzUYDH1CFlH/rwDcqz77LIDHnHMHATzWOjYYDL8kWFPUd879AxHtVx9/AsBdrfJDAB4H8Jk1R3NcIomJ8+FPYqa4qFknwH+R4oPnfWQUqaNiY5RMQfUY0CQ6EuI2gzzed66OI9FoGUPa+POMzzzcX/S8Lu5pRz5xoe9SB48haMmOJWVYn+Ne15t7O51zpwGg9X9Hl/0YDIY+YNN39YnoASI6TESHZ2ZnN3s4g8GQAd3u6p8lot3OudNEtBvAuVBD59yDAB4EgJvedWPUB8oja4BNZKs6s3tXt7J45BzhjaasCzFabrFzne1+pOcYmGL2HrJ3vQEOhS7ishlKrxXjy8sqAccIUtJtw0eh0aL8e5Eesuog6/Xm7PaN/00A97XK9wF4eF2zMBgMPUUWc95XAPwYwHVEdIKI7gfweQD3ENErAO5pHRsMhl8SZNnV/91A1Uc2eC4Gg6FH6LHnnkt0nxShZlC/lYgTavKRwlQTLqjvA9Rg/mIFFUEYiJhLc37EIsmyEYlkJajsVtWL7pWI7gMeeF0jNpZEipA1gOjWTrCLsNfkRlxm2ozLvjs6H8QGjNdpJ+arbzDkELbwDYYcovecey0xJ8pJFuPLEx54Ea87fV5AFlo+/aY4XnrjSFIuFgdE3fCB65PywI69fqyivI3SCpXNNNn8IOBemDovdJAaPdIwG8NGzBuyO5m4OzNXZkq8iPk05llHUTWLqUWxALINsNMJdTJyv9f7KOyNbzDkELbwDYYcwha+wZBD9C9NdjRkK/KByMnWQfQSw8ripaR88onvi7riFV/nqjVR13jxF0l5y613JOVtN71XDjAwFB486hLM22Uz+6X1wNBOR2zfJGz6DOmfqqpdr4HPY9roJpBVZI1WjERsZt1fCHa41onc1Be53/wtnY4q3WAiDoPB8KsHW/gGQw7Re869DJ57azA38ANZxUSmGC9b5crlpHzq1CnRbrBeT8pDA+r21BaT4uI//n1SXllaEs123f7BpFwcHJZ9xDzh2E2IkoDE+B4ySnyC0GQTPMlc8Fo6UPH484yQ4mUmZwnWrOFRmZlgI0YWki10NPYkYlqLifoGg2FN2MI3GHKInov6jVYKrfRudOduYDHvvJQqwcpDE9NJuTqxXbQ7/YLfuZ8aHxV1I0PlpFysVP05Tz2hxvKBPrvf+2FRVygPBq+AggQe3YWN8FRe3aoEssPs5BUIeMxFqatTzpwB8T4VABMJ6spm6BG1sYxiMcR2/6UHZDc9rjFAh98Re+MbDDmELXyDIYewhW8w5BC91fGdg2u0ovOooeoyEj6w9NppHZN5QKWczJgOV/KXff0H7xbtvn/0aFI+rViBJwZ9tN74kNfVV69pFWd/9mRSLo6MibodN7/fz6NYlHMUJrCMFJKZHeEixqwY62csJCykg6sPYmQe5CL7EME5ak9D1keUETSrWa4TtL9XOqXY+vuW/a/ZdA3YG99gyCFs4RsMOUTPg3Qaq5x2pH9zuNgYk1+9ipAS9fl5uo+AtLltzz7R7LaPfSIp//Cvvyy7qHmvvhLz8HPLy6LdCmu38uP/J+rKzJQ49bZ3Buef1fSUlm1D6kJYPM4qUUYc69qc0l6Ep1SG4LBKEyYBCZO4pE1x7e9kJ6FCQpUID63GitkEI9cZC6baMPXE3vgGQy5hC99gyCFs4RsMOUTPXXbrLUWFnDLnZSWhECa7CDGEJpdgewqxCKgDN747KV88L019z37v20m57rxev2W4LNoN8jleWRB1Z572UX3lkRFRN7Lr
Gn+enD2yQvD9i5qYXqkQUK7TnB+RaMjAlGNRZFmj7NKV0bC1bIiZTyM3SwwdVcJjdSH35oiDs55ihxsAWVJo7SOiHxDRESJ6gYg+3fp8mogeJaJXWv+nOhrZYDD0DVlE/RqAP3LOvRPAHQD+gIhuAPBZAI855w4CeKx1bDAYfgmQJXfeaQCnW+UFIjoCYA+ATwC4q9XsIQCPA/jMGn2hUW9F5xUiJBqpM7l4HzGtRLy7GkK14H1IlYOLWu9+352ibnHBE3j84u8fZX1Ljr2rd40n5cFByc2/NHMxKZ/8yfdE3d4P/GZSHt62y1ekLJMRE1hQTYrwyEVD5rrjyBO8Gd0FF0aiFWMzCl9n3PMto6i8IWGOXfbRL3MeEe0HcCuAJwHsbP0orP447Ni4aRkMhs1E5oVPRGMA/hbAHzrnLq3Vnp33ABEdJqLDs7Nz3czRYDBsMDItfCIqo7nov+yc+7vWx2eJaHerfjeAc+3Odc496Jw75Jw7NDW1ZSPmbDAY1ok1dXxqKlpfAnDEOfdnrOqbAO4D8PnW/4ezDJjoYB2wnGTlNZfun41gOzS8S22DlQGgzo/Vz+KvfeiupLzECDvPvvhz0W656tl5Biqyk3rdH1dPyrx9jSe+lZT3f+i3k/Lg5FbRLiuxYiydtrhXETdUREguZdrwMLrV98N6fYSEMmZuy0r0GWsW3Q8JV8jb2Llpr1kTYzyKdNkGWez4dwL41wCeI6LVb/h/QnPBf52I7gfwJoBPdTa0wWDoF7Ls6j+B8O/JRzZ2OgaDoRfoY5psjbDY2ClneLs+ZH/hIxmJJetKZW+au/3ue5LyDy5Jwo7zZ4+z/qTKUSp4UV+b+hrHjybl4z/25sL9H/wt2ccw9/jLKOOl3e6CfUS9+kItU/0HXfdUO1YVSXEtT1Hzpcg8siKrRS0anRduFzxJnUiRux/3XoxVpmG++gZDDmEL32DIIfog6qv/qxCbzJ3ISd2AMzcUg1VwcsffseMhFmDza3fJYJ7H//YrSXlmuSLqJod9Sq3aclXUVev+umsvv5iUB7dK36i9t32QTV8/Qh5Uk5WjLbx7LMg2IuJ86pkFdvJjjm+pmKuNZJ4ID9scO1DurNNIAFJkcArs5Hcyj07vlL3xDYYcwha+wZBD2MI3GHKIHvPqA40WB33aUtNgzbJpN2myTa7pFIJteYRfLD0ZGtIU12D8+Q1Gtjm9c5do9/bb3puUn/+H76v+/byGCnJ/ociOS2wil19/TrRb2O0JQif2HkQY3e6VcPNYmFEzmrMuRFCZ3fVNDJiZpyWzspv93mT1HI2PxnV37dXXfqMjvQ8RNvV1ui9hb3yDIYewhW8w5BC9N+e1xJWUlBiQLlVVhMccoAarLcQiT/hBMdiMlKmPeNovXlYX846bbk7KR195SdSdevP1pHxw5zZRNz7uTX07tnn+/WJR3pBLL/oUXYPjkvFscIIF9MTsaBH7kgy+CROfxBDi8kiLpMJ1LzJH3ix6MbpxllaKET9C5hHVVFz7j2OnKFBkvqHnslaf7WBvfIMhh7CFbzDkELbwDYYcose58xzTz8J2NK3Dhc174Y0C5+RvWkG0inDzM5OadlEtFNtrUnXUxPHwyGhS/sA994q6Hz7yjaTcKMrzpqYnk3J5wD+agvp5ri3MJOULz/1I1O045N2HS4OjCCIjB7y05mV1AY4E53XkhpvR/Tijm2tUK94YJpj1I+O2jL6Nq/sSWadgb3yDIYewhW8w5BB9i85LRXNlTBEtnZxiZig9bMALTLXj4mxBR74RVxH4b6YcjQXZYWrHTlH33ru96P9PP5S8+nMLPt3WMBP1lQOhEL+vnHhN1F0Y8pz+O27+QFIulGSar6iNTZiUsr0buuXOz47u5OiQahGbbvRSuhTns0bdRSgOoyGEq0sh62OwN77BkEPYwjcYcoiei/qNkKwk5FnNqcZqxK5nJJgi5QUWErW0zBRWFwqFEqsLe5zVKj6Tbr2yJOq2797t
694jU3RdeP4nSXn0sqfvHipJbr5Cyf9eF5W6s/TG80n50qgX+8fffrNox0V/fZ0U2lruVpyPBPpIj81sO+0x77+0ChkYK9JnlCwkXBX//kX6z6wmbcSzaMHe+AZDDmEL32DIIWzhGww5RB959WMMGGENLO74xfX4FGsEaxXW42OhWGJ/gfV39qVfiHbnXmYptWrLom54fCIpj+85IOqqe9+WlBfP+fRaNKjmwRz+6ktXRN3kiNfrV448nZQXVlZEu7F33JaUi4ODok5a+vhFd/LMAogp0BkZNtLNIoQgwarspmDpERocKo4NMHcKFT+wHbJhnntENEREPyWiZ4noBSL609bnB4joSSJ6hYi+RkQDa/VlMBjeGsgi6lcAfNg5dzOAWwDcS0R3APgCgD93zh0EMAvg/s2bpsFg2EhkyZ3nAKzalsqtPwfgwwB+r/X5QwD+BMAXM/TX/N8B51l2rvGIUSZgUko7rUWJ3pPiiRd/lpRfe+LbolmJkXQUS+palueTcv3SeXneiA/SqRf9b3JD8ftXrniu/suzc6KuvuTbbmEqweLlH4t2l+d82q8tNxwSdSNT25NyoRgRgSPmsWDDLt3iMmqCqcpQj+lryUaikUbAxpaNUySFrKa99cYDZdrcI6JiK1PuOQCPAngNwJxzbvWrdQLAnnXOxWAw9AiZFr5zru6cuwXAXgC3A3hnu2btziWiB4joMBEdnlVvJ4PB0B90ZM5zzs0BeBzAHQC2ENGqqrAXwKnAOQ865w455w5NTW1Zz1wNBsMGYU0dn4i2A6g65+aIaBjA3Whu7P0AwCcBfBXAfQAeXqsv54CEmj4SERYztTj2W6UJB2Wn6jfNMZdgXlTNiO0FOJK6da3ides3nvtpUq4sL8o+ePpr2QWqbPCCJhyp+f4LLPqvrmw3V5a8ifD8BSlFLY96xX6FXecQy9kHAI3nvKnvwrHXRd0oMzMOTnnyzuEtW0W7LYzfvzw4IurCRPiayDKbUutiSn5GE++GILM5uZPTst0D8X1XmwHJVynj5Wax4+8G8BA1KWcLAL7unHuEiF4E8FUi+m8AngHwpWxDGgyGfiPLrv4vANza5vPX0dT3DQbDLxn6wKvfFFHi3GtSjGkw8TgaWafO4nCMc58TeJDToic7x8k+qlUvYi+yFNcrVSnPF4v+WHPzN9jYtVo4RVcBvv9KXUb4LTNRf35BqhlLFT92jY09rsYaGPDqSGF2RtRdmTmXlFcafh7cxAgAIzuvScrX3Pw+UbeVeSEOjIwlZS3ax826IcTMfpHvROcU/l2jkwg/6RnY/nvabBW5gA7vo/nqGww5hC18gyGH6LGo7+Aa9VYpLFxp8U/u9zOxP8Ug4X/HCoqTmtNoO0G2oQjt+LES9bnoNTjtvdtOHX1FtKsyMWzUSa674TIj8yhLNcBV/Y58veKDasqq3cCg73Pn7mlRd/ykF9trM37Hv6pE/dExL36XirL/8gDP2svuY13SgS8c9+nBnj/9hqib3O4JR7a97QZfPnC9nMf0Dj9WWQULcWQmoYiIvBmrNqL7uAdeuFKet3lEhvbGNxhyCFv4BkMOYQvfYMghem7Oc42mnthIEWXwotbx20dO6fRXguveaR2fefwVuCegNhDyiWj939ddf+t7kvKJY2+KVsdYKuzJEam3To56M9rEkDQDlvjYjHy0prz/yowoc4wRewDA7l3+et48eSEpLy1LIo6tK15fHx4eEnXDw37OJba/UFP3Y6Xm+yiWZN3SRe/BfXr+bFKeefmwaDe6zcd2je+WxCRT+6717dheABXknoQwh2U2+HYXHRpvyU/K7rsnHRuzkcTECEezwN74BkMOYQvfYMgheivqO4dGvSm3aiFaNNOifjCvUJgTv6Gib4iZ98jxQB8dpRPpnpUnp7wZ7Z7f/qRo94Pvfispv/Tsz0Xd3KL3upsekWxlY2U/l1FmshtQqbw4OUbBSbF3eMgH4+y/xgfRXJiZF+3m5v1xrS7V
gErVi/4FJlZr6XW56oOKuAkQAArMa3BgjGX+ZecAQOWMV4sun3xV1J19zucZ2HOrTwe2452SOKRQ4vkOukSM3z92mkg3xk3GyiM0Zi8M1kW4JzPPsD3sjW8w5BC28A2GHMIWvsGQQ/RUx3dgLreaS4GTHUZyqDnX9mMA0rynI5u4mU6UU+28blrQupjo3388Pjkpmt398d9MysPDkqDi6Sd+mJQXFiQn/tig/x2eZDr+9PiYbMf2LwaLcpLDY57l6Npr35WUr5w/J9rNzvnj1xQRx5lT3vzmmI5PKjqP50EcHZVEH3W2IdCo+/L4uOxjmO0NlFXEWeWSdz8+9pPvJ+XB8SnRbuqadyCErBnstGk41FTvP0mKkfCXU+j/kT7ksMo0GeGB7RT2xjcYcghb+AZDDtF7z70sdgjNRRc6iJpFwkQFwjyoJ1TwhsaUSZACv5NqHkNMvP9n93xU1rF0VT/67ndE3cycF/0XSl4EXqxI171tE/54dECa4uol3//chUu+XVGK4vt2eY+5WkOa2F4/+azvg3H4Q4n6Q0N+rOGKjNxbZu6Gy1V/TxcryoNwYjQpjw9J86ar+2ezxMyPJ198RrTbwjz8qKCfUUZOv+CXTB+q76bM2x5sx3hgUFB20SD3f+o7HEkl14EJErA3vsGQS9jCNxhyiD5w7rWHINtQolBD+PmxnfUI75iOrxFthVikdlg5711Bb8225zyrN6Qo3mCEFVrQfPd7WHDPm0dF3Zsvvuj7ZDx+F+Ylr94yE6unx6TVYHHRc+Q1Fvwcd23fIdpNTnixuqSCXkZYn4tsHvo663Umzleqom72kp9zjYn9msdwsMytKPJu1RkxSaXqVYSLJ2VQ1MoVP9bQmAxakujS3y1IFQ7xPaCI2ydFvO4itBzZ5hRv2Rb2xjcYcghb+AZDDmEL32DIIfqm40e1La2/COubP2gofVF4/6nU0sGwuyjbQfh30TW42U+asriO31AElXyr4eq3Xyvqzr/pCSsbV3wUX0WZ87geTw3JuV+e9I/08hXv+TY7J/tYXvGegcfOnhF1FTbHEksHVqrJaymyqLjFK9ILseS8qW+F7ZU0RiUxydCAPy4PSkIQYiSgVTanpcuXRLulBW/qi+v4rG91HKXliFa2191jOnfXXncbyL2Z+Y3fSpX9DBE90jo+QERPEtErRPQ1IhpYqw+DwfDWQCei/qcBHGHHXwDw5865gwBmAdy/kRMzGAybh0yiPhHtBfBxAP8dwH+gZkTDhwH8XqvJQwD+BMAXN2RW2kwnRHgmNjZUCqo6E7+19x+nsxP9xYbWGUm5eM9MVMrMxY8bSuVoMNPWoJMmsOt2+2CfAfLl+cuy3bkLl5Py0ooUv8/MLvh5sKzkjYJsd+mc98g7NS/FdM6t1wjce0AGEu2aGhd1owP+nbJ12l8LVw+affr7UV2RHoRFZurjuQWWFXHIwkUfVDTFMvgCMiAmnLNXIea4p82/IuVVhC8vsyKQ1eSYiiDLeF4TWd/4fwHgj+GJc7YCmHPOrX6bTgDY0+5Eg8Hw1sOaC5+IfgPAOefc0/zjNk3b/lQR0QNEdJiIDs/NzbdrYjAYeowsov6dAH6LiD4GYAjABJoSwBYiKrXe+nsBnGp3snPuQQAPAsD11x1cL1WYwWDYAKy58J1znwPwOQAgorsA/Efn3O8T0V8D+CSArwK4D8DDHY0cyfKrdUmhZ3IzWl2nmfbHWndvCBdbrqursdhxXe8TcN2duas2lI7P/YW1i+rs6ZNJefH4y6Juz1Zvihoa9tF025eljr+N6dOnzs2JujfPelPXc8c92cb4qHTtrbIbXtA89ey+FgSBqbyWcskLjLfeKE2TW8baG3kc5OcrPFWhastzBK6UvdmSlpdFu/kz3oXX3XCb7ITaa/Y60jKytZNd7RZkMmoasf4C6nkskrWfRByfQXOj71U0df4vrW8qBoOhV+jIgcc59ziAx1vl1wHcvvFTMhgMm43ep8luyS9p
Sb99mqzmB2Hzm+5lFVoUknxl7bnz0t1F1IB6e9MeINUR7UE4c+Z4Uh6oS/NVo+7F8aXLXpwlZQKbmvTkFRMTUoSfYsQWr544n5Rnl+Q8VhjJxaAi2KgxNWZ4yIvbZeWjtVzx8z97TnL6jQxexebrzXm1hrzhtMLmpZ5Fnc2xWuditJzvpbNefVpZkpGMA6OcrzArL71OzSZI92NNg4iZBEOTSX83s6eWXwvmq28w5BC28A2GHKJ/RBxRjzmFgGheKGiRjKV7inBocMmtoem12TGpXWy5a8s891Lz5efIOa4wworFS9JjzjHnunLBP5rygBTTi2w3XWe6PXDNzqS8c7u3Epw8OyPaHT3jfSrmpcaBOhOlJxhtNqf/BgBU/bWcvTArqgbKXi1YWfF3aHhEqiZllh6sUCyLOh7gxMlZtCVmad5fG/fiA4Bto9Kj0HfY/mOgHdcdr4werrMGyGxCyEReGYa98Q2GHMIWvsGQQ9jCNxhyiB6nyaYkmsmlEmWH0w+FApsKUB5nhUBD1YWI1FP6YpFNS0f4FcTYvmFN6YSclLKqyCsWLnnPurnZy6JuaclHnY0xbv4RxTc/wvT6Yknq/9xUOT7mTXvXj8q9gH07fOjeqfPS++8Y4+UYGfZ6945tMlVY9Yr3pqssSW+6y8y7bnDJtyuWpR5fZPkDSJGbcl59cY7izq9U/FgXj8t0YNuufjs74hs9mgyz2+i5UMyf2jvKHBsYydMuCGPX57pnb3yDIYewhW8w5BC999wLijlhMYmDggfaI0+lvxJeW6xc1GK6L2s1QJgEY/Ng0v383AVRV1v0ZrQFlXZqccmbx7YxMVdnqS0P+kmWVKBSocYCbMD7l9dSLnmR+6ppafIaZCL33AILjlFeiCU2r+KQ5NIrMRGeZ9VdURyEgsevJNUAx0R6nhqLVGqzOntO54+/IequZSbHYlnOccOhvfp4Vew7vYFcellhb3yDIYewhW8w5BC28A2GHKK3Oj7B68NKf45ZOFzA1KddK7nurl0ahY4olHL528fz5cXoDF05HIHXqHm98sKx10RdkXHwDw1IM925GW/qW1r2frQrK5JckpsIJ1TuvNEhb7ar8bx0kOBEIlWVupoTiZSY2bKiqNOGmFtueUDURuChAAATL0lEQVTqzwPMBMn1fZ3vsMZ0cChCEBcgAeHRg4CMlJyfkVGCFRatN1Lm97sTxTr25QyZATXZK48c1SSurAf2nU67nfM02Xr/aXPINg0Gw68QbOEbDDlEjz33wIg4wuKTFmNCTdNZspk4VdDiVHtRS4tInGMuZZ1hxyXnTU+uJCdSafiou8vz0iuOewOODMnbv33Kk0bUmPhdXVEpqGe9yH35sozw4159I8Ne/B4ZlKI4txAuXVFed5e9msFNZUX1ICZYn5Mq6u5Koz0noU61PVTy0X9FFZ1Xr7NUYUwErtbk/Vhhx+7ygqirLDJRf3wLQhACu3aY43VZOfcj0XMpMT1yFJxHMLw1W9SevfENhhzCFr7BkEP0nohjVSRJSSphMSkl+reQ4tXjHnmpWIr2u/VpWrNIGiTmDdhgRBlFJeoT252uqB1ozh03pAJWto55sXeccd1tVemp+DyOn5SegSdP+2Pu1DcxLkXxoQE//5pSJSrsuMivk6SXYG2E75LLa7my7FUQYmQbJbVzX2G7+jX10KqsbpHx+y1W5HyvcPrxgrzflWXveSjVPShEiDMy8nLEBPiYCO5CB1G9IsritybsjW8w5BC28A2GHMIWvsGQQ/Rex19VfdK2uDVPSbXSKpDQ41OV7IDpqilVjP0WpjYR/LEg+nTyNg6PeZ18YttuUbd0yZviiposlJWHBn2fw8rDjxN27tkhTVTjI96cd+yk59U/cUruBfDbwQkvARV1x3TygaKcx9gA4/4vqtwCLP1VnXkCLisvxAZ792hzHn+e1UY4nfYKP1ZzdBFvN44YTYbMyRA5M84YGxnBta2hFIc/9+oL2bgjU2DItPCJ
6CiABQB1ADXn3CEimgbwNQD7ARwF8K+cc7OhPgwGw1sHnYj6H3LO3eKcO9Q6/iyAx5xzBwE81jo2GAy/BFiPqP8JAHe1yg+hmVPvM2udFBO3PMKmkHgoQiRIgon3DSZCaXFKpkuKzzI5R4ns5bIXt/dde4OoO/XGK0m5qu/Fig++mWWc+yoGCCuMm29A8fGNjfrjm264Jinv2i5VglcZsd6FGcn9x7g8MM6y9m6ZlmMV2OHsouxj2/jWpFweZA0Vj57wulOmT/4AVqr+mhuKx5B7NjZK8llUlvx9rDPzYEGpN4WINydXF9rkd9MfBD4P9xH2yAsH86SXyKpHbDZkfeM7AN8joqeJ6IHWZzudc6ebY7rTAHZk7MtgMPQZWd/4dzrnThHRDgCPEtE/ZR2g9UPxAADs3LG9iykaDIaNRqY3vnPuVOv/OQDfQDM99lki2g0Arf/nAuc+6Jw75Jw7tGVyol0Tg8HQY6z5xieiUQAF59xCq/xRAP8VwDcB3Afg863/D2caMYuOH4nO41F9KfIB194sAkTS2en8eMKfV/cf2EOIkCJsv+pqUTc46aWe8yePibqhoj9vbt7rtLMDMvVzmZnYeAQeAEyRj/AbH/PmsanJYdHuxmu9mXFmTvZ/7qKPcKvx9NpKt641vM589V4pzQ0Oe72+UvX3pzwi59tgGwp1peOv1PxxrerLlRUdnefrlpcvibqnHnvEn3fnPUl55979ot3ouDfBFpX+L75nKT5+DxfdiwoTbIhW4sRIQ52WImmbTcvPIurvBPCN1sWXAPwf59x3iOgpAF8novsBvAngU5lGNBgMfceaC9859zqAm9t8fhHARzZjUgaDYXPRc8+9VWtZI2XT4GK6NsWFPJZiAn0EsQgoJkO5lDmFeYExggqdaovz8Q8oLrrrb7k9KZ88cULUzc97kZuTXixckaLtKPOKk8I3UGD8dpybrq7IK7hYXVa8/bum/V5MGV5k3z41LdpNbvMpuvZeJev42OeYKnFxXhJlcE1L3++Vqr+6JSbeX16WnnsVdm0FJQNfeONIUn5y1m9Dbdt/vWh33a3vS8q79u0XdYOMx7Cg8zUIU253vPoC4n5IUNha3XF4nvnqGww5hC18gyGHsIVvMOQQ/cudF9VJNANPe8VH64QxHagRNL9puwjjctc6PtPd6w1uhpJ98GNVha279iTla659h6h7+RfPJuVF5pa7rExIK2zOy3Wp7y4se714ZIBF1umfeHYP9Pyp4cfbMuTNgCX1nhhlvP2czx8AVlZ85N6urd5duKgYeM5d9HFdCyydNgAssyhEHtVXLMl5jJX8PkRB7dkMsLbVhZmkfPq5n4h2F469mpS3KzfrA9f7ve09V+8XdcMjfp9D5GvsMh8eZ5HSEaYukpsvxGwVgr3xDYYcwha+wZBD9FTUd+D+RRGywIipIhbdJ81+7UZvgqdj0qmwuWmuoSLJ6kI8DovKnIs+1T87PnD9TaLuwpnTSZnOexKNWlX2scLSRNdr8rf7ChPT55ioPOCkOY/YPSgpVWKIea4NFL3qcGlZcvjvLk36PnQq7xGvIjQY9//EiPQgFPenIL3uqjV/PM6ISAfLcqzlZa9W6GdREFGZ7Jkp8+bSxZNJ+fj8eVF34bUXk/Kx/VI9e9u7bkvKO/d4L82h4VHRTqRwiwbuhVK9KVU2EJ2XFfbGNxhyCFv4BkMO0Tde/bSkEiPAoLbFNP8+L2vxu30wSF0F6XDxXovptYCKoNtxaTOmSoyMT4q6W+68Oym/8NQTSfnscRnMw1UV0tl+2c4yJ5uo1jVZCOcPlDvtnO9vYsqL2Afevke0276TkXsoj7ZC0fdZLnt1obogM+4ushRXS0syldfIkPd6LPHdeSWm8+c+MKACbNi9WmnwlFzyfgyy/guKm7+w4D3+LhyR7HLnj/oI9Z0H35WU337DraLdtp1XsTlKb04ilk2Y3Uc9R4rJ+q3DjSbiMBgMv0KwhW8w5BC28A2GHKL3
On4AQaIMSPNHLDaP7xOkdOt6e087reOLOkUMIXR3Hk2Y0vF55J6cI5+XtsCMTHid+fpDdyblmiLAuHTeE2UqJzYMMnWd68VFknrl2LDPpbd1YkzUXbXNR+ftZpF605Myhx/PpVfVBJhVpltX/H0kdUNc3Z83VFL6OYs05KScSxXJzc+/PYNlSQiKhu/fsWet+FFR5DkZUmmseZ0iI1m4mJSPP/OjpDx7RkZe7jl4Y1Lezcx+ALBlihGTDrBIQLX3Uij4qEwqqKXbetR6bysEe+MbDDmELXyDIYfog6i/yv8d8dzTEBx54TRCTpjblJkuRPSR4vdrBOu4GBVpJklFGroq4l3Ijrmp77pfe59o9vLP/jEp0+KcqNvCUmhNjPrypBLnOV/+1Jismx73dQOMB7BWlSJ2jd+P1ANkqcKZaF6vSfVpgJkcKwXZySLzyLt8xQfw1JQKNjDA0nUpVaLMzIpjjO+vXtMUJkw9U3WxeBseFOSqPmBq4c1XRLsXznjPwKPbdom6PQeu9eWrr0nKo6NStSozM2BBpRsrtDwDueoUg73xDYYcwha+wZBD2MI3GHKIvpnzYnpTmpiwQybBQK+dfZ6uS0VVrbaKRALG+NXTbpeuXRHjW7aJZgdvfX9SPvrck6Juecnr/GPOm7ZIE47UvXmssix59ZfZt8IxE1tDbVhw/TkWNcndoHWeQX4/lhVf/qVFr+Nz7nzdBTfBLq1IHbfByUIYMcnwkDSVOZagsK4iDfneTkOTlhCvC6fMrix7ktG5YzLP4Mzp40n5tSM+P8H2q/aKdtNb/fdgRO3LrO5zrKgU4iHYG99gyCFs4RsMOUTvRf1VaSiVnpqnrpbiVChtVoqTLBDEp49lliwtbvN0yYpDnZgHmuOfq8F4Gm6lHxQEpx+CdfxatGg7vsV7eu2/6f2i7viRZ5LyydlTSfnSZclnN8FSb20ZHRF1dZbWapRHyCkReJCZ0fQbpMFE51qDlRW5yeJSlZWlmMpNsPw5FdQ8+APQOQ6qTDSvMC2gPCjNYSVGlKG587lN1pXkU+PpvMq8Swp/hxsNlSps2Yv+i6c8ackCMwECwFEe1VeS8y+2Bl+Yl+bdEDK98YloCxH9DRH9ExEdIaL3EdE0ET1KRK+0/k9lGtFgMPQdWUX9/wHgO86569FMp3UEwGcBPOacOwjgsdaxwWD4JUCWbLkTAH4dwL8BAOfcCoAVIvoEgLtazR4C8DiAz6w9ZPv9fJdts1tlJI2oCwUlpguxMdxHgdqL27r/Avcya8iGjBIPBVUXzagamGOabtyXh0bk7u6+G9+TlM+f8AQeF0++KtrNzXg+u4VFKWJfYSmqpsY9dxwnxgCAAbarL/fIAS7RV9iO/Pxlyds3w46XlddZlQfV8L7VznqRz6Mo72qRPbM6e6BLNdkHJ+Ioqu9Okatu6l3pBnzdCrzYr60cBRZUo1WmCksVtsL4FWt1RTiyzIhElP6XWFw20HPvbQDOA/jfRPQMEf2vVrrsnc650wDQ+r8j04gGg6HvyLLwSwBuA/BF59ytABbRgVhPRA8Q0WEiOjw/f2ntEwwGw6Yjy8I/AeCEc27VU+Rv0PwhOEtEuwGg9f9cu5Odcw865w455w5NTk60a2IwGHqMNXV859wZIjpORNc5514C8BEAL7b+7gPw+db/h9fqi8D0Wm0Di0SticOoy1/YjMZ1ooLjSrgmwwyb87jqxyMBUxZBYZaLXEwq+i8b93+MLJSTNWzdcyApj03LiLBLFzyH//y546JuYdZLZheZ99yk0vGHmDmvqB4MJ+ZYZCavhYrUQZeqTI9VJrDhQe95yCPwiuqGc886/SwK/LlLO65oV2P3uKDYTfjXQJtWuVHNlVgfqv9qlZnwtJNj2S/DQoGZH1ek2Y9/XUpFvRCo7fxCyGrH//cAvkxEAwBeB/Bv0ZQWvk5E9wN4E8CnMvZlMBj6jEwL3zn3cwCH2lR9ZGOnYzAYeoG3DBGHoM7XXn1cdOYe
c6pnKV1pr7uA91/KO699Oz0PUY6I8ynCERcR4TlXHxPna8r0JHgBY3U8OEYRN4xt8wEgI9O7RV1tiXmSXfTefyfm5TYOzXuO/ILmLmTXsszMeVXllcktWxMjki+vzDj3ikyGLSp5tsyuLcarL/MuQILpcZonscoal5UaUGAedCX+nVAmQe4NmMrXwJ51gZn2ysrsJ85LeYs226Y8UQMwX32DIYewhW8w5BC28A2GHKIPOn4SnhdsEaPG4DqMdovk6o3W4TJaOUQfKXWJEyuKlNxhs5zTpJ88h1/KTOfrakwvTqXhrvN9At1/+9wCOkdAjfep7lV5xPP7bx2f9v3VJNnm0qWZpLx46aKoazCizAIz7VXn1D6B8+7B2gTGHyLfyyiV5de2zMhCtAs2vzQe/ZjiJaHwBkCZ6//qvAKrK7LwPKfdybn+Xw9H7pW587N6Lvy56zWySrSS1Zxnb3yDIYewhW8w5BAU40rb8MGIzgM4BmAbgAs9G7g93gpzAGweGjYPiU7ncY1zbvtajXq68JNBiQ4759o5BOVqDjYPm0e/5mGivsGQQ9jCNxhyiH4t/Af7NC7HW2EOgM1Dw+YhsSnz6IuObzAY+gsT9Q2GHKKnC5+I7iWil4joVSLqGSsvEf0lEZ0joufZZz2nByeifUT0gxZF+QtE9Ol+zIWIhojop0T0bGsef9r6/AARPdmax9da/AubDiIqtvgcH+nXPIjoKBE9R0Q/J6LDrc/68R3pCZV9zxY+ERUB/E8A/wLADQB+l4hu6NHwfwXgXvVZP+jBawD+yDn3TgB3APiD1j3o9VwqAD7snLsZwC0A7iWiOwB8AcCft+YxC+D+TZ7HKj6NJmX7Kvo1jw85525h5rN+fEd6Q2XvnOvJH4D3AfguO/4cgM/1cPz9AJ5nxy8B2N0q7wbwUq/mwubwMIB7+jkXACMAfgbgvWg6ipTaPa9NHH9v68v8YQCPoBlW0Y95HAWwTX3W0+cCYALAG2jtvW3mPHop6u8BwMndTrQ+6xf6Sg9ORPsB3ArgyX7MpSVe/xxNktRHAbwGYM45txpR06vn8xcA/hg+bGZrn+bhAHyPiJ4mogdan/X6ufSMyr6XC79d3FAuTQpENAbgbwH8oXOuL5zjzrm6c+4WNN+4twN4Z7tmmzkHIvoNAOecc0/zj3s9jxbudM7dhqYq+gdE9Os9GFNjXVT2naCXC/8EgH3seC+AU4G2vUAmevCNBhGV0Vz0X3bO/V0/5wIAzrk5NLMg3QFgCxGtxrj24vncCeC3iOgogK+iKe7/RR/mAefcqdb/cwC+geaPYa+fy7qo7DtBLxf+UwAOtnZsBwD8DoBv9nB8jW+iSQsOZKQHXy+oSSbwJQBHnHN/1q+5ENF2ItrSKg8DuBvNTaQfAPhkr+bhnPucc26vc24/mt+Hv3fO/X6v50FEo0Q0vloG8FEAz6PHz8U5dwbAcSK6rvXRKpX9xs9jszdN1CbFxwC8jKY++Z97OO5XAJwGUEXzV/V+NHXJxwC80vo/3YN5fABNsfUXAH7e+vtYr+cC4CYAz7Tm8TyA/9L6/G0AfgrgVQB/DWCwh8/oLgCP9GMerfGebf29sPrd7NN35BYAh1vP5v8CmNqMeZjnnsGQQ5jnnsGQQ9jCNxhyCFv4BkMOYQvfYMghbOEbDDmELXyDIYewhW8w5BC28A2GHOL/AzIb2WTyeD8aAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "index=8\n",
    "plt.imshow(X_train_orig[index])\n",
    "print(\"y=\"+str(np.squeeze(Y_train_orig[:,index])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1080, 64, 64, 3), (120, 64, 64, 3))"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train_orig.shape,X_test_orig.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1, 1080), (1, 120))"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_train_orig.shape,Y_test_orig.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train=X_train_orig/255\n",
    "X_test=X_test_orig/255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_one_hot(Y, C):\n",
    "    Y = np.eye(C)[Y.reshape(-1)].T\n",
    "    return Y"
   ]
  },
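  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`np.eye(C)[Y.reshape(-1)]` uses NumPy integer-array indexing: row $k$ of the $C \\times C$ identity matrix is the one-hot vector for class $k$, so indexing with the flattened label vector picks one row per example and gives shape $(m, C)$. The `.T` inside `convert_to_one_hot` turns this into $(C, m)$, and the extra `.T` applied when building `Y_train`/`Y_test` brings it back to $(m, C)$ to match the `(None, n_y)` placeholder. A tiny illustration with three made-up labels (the values are arbitrary):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def convert_to_one_hot(Y, C):\n",
    "    # Row k of the C x C identity matrix is the one-hot vector for class k;\n",
    "    # fancy indexing picks one row per label, and .T makes the result (C, m)\n",
    "    return np.eye(C)[Y.reshape(-1)].T\n",
    "\n",
    "labels = np.array([[1, 0, 5]])           # shape (1, m), like Y_train_orig\n",
    "one_hot = convert_to_one_hot(labels, 6)  # shape (C, m) = (6, 3)\n",
    "print(one_hot.T)                         # one row per example, as in Y_train"
   ]
  },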
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_train = convert_to_one_hot(Y_train_orig, 6).T\n",
    "Y_test = convert_to_one_hot(Y_test_orig, 6).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1080, 64, 64, 3), (120, 64, 64, 3))"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.shape,X_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1080, 6), (120, 6))"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_train.shape,Y_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create placeholders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_placeholders(n_H0, n_W0, n_C0, n_y):\n",
    "    X = tf.placeholder(tf.float32, shape=(None, n_H0, n_W0, n_C0))\n",
    "    Y = tf.placeholder(tf.float32,shape=(None,n_y))    \n",
    "    return X, Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "X,Y=create_placeholders(64,64,3,6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor 'Placeholder:0' shape=(?, 64, 64, 3) dtype=float32>,\n",
       " <tf.Tensor 'Placeholder_1:0' shape=(?, 6) dtype=float32>)"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X,Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_parameters():    \n",
    "    tf.set_random_seed(1)\n",
    "    # Filter shapes are [f, f, n_C_prev, n_C]: 8 filters of 4x4x3, then 16 filters of 2x2x8\n",
    "    W1 = tf.get_variable('W1', [4, 4, 3, 8], initializer=tf.contrib.layers.xavier_initializer(seed=0))\n",
    "    W2 = tf.get_variable('W2', [2, 2, 8, 16], initializer=tf.contrib.layers.xavier_initializer(seed=0))\n",
    "    parameters = {\"W1\": W1,\n",
    "                  \"W2\": W2}    \n",
    "    return parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.00131723  0.1417614  -0.04434952  0.09197326  0.14984085 -0.03514394\n",
      " -0.06847463  0.05245192]\n",
      "[-0.08566415  0.17750949  0.11974221  0.16773748 -0.0830943  -0.08058\n",
      " -0.00577033 -0.14643836  0.24162132 -0.05857408 -0.19055021  0.1345228\n",
      " -0.22779644 -0.1601823  -0.16117483 -0.10286498]\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "with tf.Session() as sess_test:\n",
    "    parameters=initialize_parameters()\n",
    "    init=tf.global_variables_initializer()\n",
    "    sess_test.run(init)\n",
    "    print(parameters['W1'].eval()[1,1,1])\n",
    "    print(parameters['W2'].eval()[1,1,1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forward propagation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_propagation(X, parameters):\n",
    "    W1 = parameters['W1']\n",
    "    W2 = parameters['W2']\n",
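    "\n",
    "    # CONV2D: 8 filters of 4x4x3, stride 1, 'SAME' padding -> (m, 64, 64, 8)\n",
    "    Z1 = tf.nn.conv2d(X, W1, strides=[1, 1, 1, 1], padding='SAME')\n",
    "    A1 = tf.nn.relu(Z1)\n",
    "    # MAXPOOL: 8x8 window, stride 8 -> (m, 8, 8, 8)\n",
    "    P1 = tf.nn.max_pool(A1, ksize=[1, 8, 8, 1], strides=[1, 8, 8, 1], padding='SAME')\n",
    "\n",
    "    # CONV2D: 16 filters of 2x2x8, stride 1, 'SAME' padding -> (m, 8, 8, 16)\n",
    "    Z2 = tf.nn.conv2d(P1, W2, strides=[1, 1, 1, 1], padding='SAME')\n",
    "    A2 = tf.nn.relu(Z2)\n",
    "    # MAXPOOL: 4x4 window, stride 4 -> (m, 2, 2, 16)\n",
    "    P2 = tf.nn.max_pool(A2, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding='SAME')\n",
    "\n",
    "    # FLATTEN -> (m, 64)\n",
    "    P2 = tf.contrib.layers.flatten(P2)\n",
    "    # FULLY CONNECTED without activation: the softmax is applied inside the cost function\n",
    "    Z3 = tf.contrib.layers.fully_connected(P2, num_outputs=6, activation_fn=None)\n",
    "\n",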
    "    return Z3"
   ]
  },
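  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `padding='SAME'`, TensorFlow sizes each spatial output dimension as $\\lceil n / s \\rceil$ for input size $n$ and stride $s$, regardless of the window size. Assuming that rule, the short sketch below traces a 64×64×3 input through the layers of `forward_propagation` in plain Python (`same_out` is a hypothetical helper, not a TensorFlow function). It shows why the flattened vector fed to the fully connected layer has 2·2·16 = 64 entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def same_out(n, stride):\n",
    "    # TensorFlow 'SAME' padding: each spatial output size is ceil(n / stride)\n",
    "    return math.ceil(n / stride)\n",
    "\n",
    "# (layer, stride, output channels), in the order used by forward_propagation\n",
    "layers = [('CONV 4x4, 8 filters', 1, 8),\n",
    "          ('MAXPOOL 8x8', 8, 8),\n",
    "          ('CONV 2x2, 16 filters', 1, 16),\n",
    "          ('MAXPOOL 4x4', 4, 16)]\n",
    "\n",
    "n, channels = 64, 3  # input images are 64 x 64 x 3\n",
    "for name, stride, channels in layers:\n",
    "    n = same_out(n, stride)\n",
    "    print(f'{name:22s} -> ({n}, {n}, {channels})')\n",
    "\n",
    "flat = n * n * channels  # size of the vector after flattening\n",
    "print('FLATTEN'.ljust(22), '->', flat)"
   ]
  },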
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Z3 = [[ 1.4416984  -0.24909666  5.450499   -0.2618962  -0.20669907  1.3654671 ]\n",
      " [ 1.4070846  -0.02573211  5.08928    -0.48669922 -0.40940708  1.2624859 ]]\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "with tf.Session() as sess:\n",
    "    np.random.seed(1)\n",
    "    X, Y = create_placeholders(64, 64, 3, 6)\n",
    "    parameters = initialize_parameters()\n",
    "    Z3 = forward_propagation(X, parameters)   \n",
    "    init = tf.global_variables_initializer()\n",
    "    sess.run(init)\n",
    "    a = sess.run(Z3, {X: np.random.randn(2,64,64,3), Y: np.random.randn(2,6)})\n",
    "    print(\"Z3 = \" + str(a))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_cost(Z3, Y):\n",
    "    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=Z3, labels=Y))   \n",
    "    return cost"
   ]
  },
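  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tf.nn.softmax_cross_entropy_with_logits` applies a softmax to each row of `Z3` and computes the cross-entropy against the matching one-hot row of `Y`; `tf.reduce_mean` then averages over the batch. This is why `forward_propagation` returns raw logits with `activation_fn=None`. As an illustration only, here is a NumPy sketch of the same computation (`softmax_cross_entropy` is a hypothetical helper). With all-zero logits every class gets probability $1/C$, so the second example below contributes $\\ln 3$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax_cross_entropy(logits, labels):\n",
    "    # Stable log-softmax: subtracting each row's max does not change the softmax\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
    "    # Per-example loss -sum_k y_k * log(p_k), then the batch mean (the role of tf.reduce_mean)\n",
    "    return np.mean(-np.sum(labels * log_probs, axis=1))\n",
    "\n",
    "logits = np.array([[2.0, 0.0, 0.0],\n",
    "                   [0.0, 0.0, 0.0]])\n",
    "labels = np.array([[1.0, 0.0, 0.0],\n",
    "                   [0.0, 1.0, 0.0]])\n",
    "cost = softmax_cross_entropy(logits, labels)\n",
    "print(cost)"
   ]
  },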
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-58-c5ef55c7beaa>:2: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n",
      "\n",
      "cost = 4.6648693\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "with tf.Session() as sess:\n",
    "    np.random.seed(1)\n",
    "    X, Y = create_placeholders(64, 64, 3, 6)\n",
    "    parameters = initialize_parameters()\n",
    "    Z3 = forward_propagation(X, parameters)\n",
    "    cost = compute_cost(Z3, Y)\n",
    "    init = tf.global_variables_initializer()\n",
    "    sess.run(init)\n",
    "    a = sess.run(cost, {X: np.random.randn(4,64,64,3), Y: np.random.randn(4,6)})\n",
    "    print(\"cost = \" + str(a))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def random_mini_batches(X, Y, mini_batch_size = 64, seed = 0):\n",
    "    # Split (X, Y) into shuffled mini-batches.\n",
    "    # X: images of shape (m, n_H, n_W, n_C); Y: one-hot labels of shape (m, n_y).\n",
    "    # Returns a list of (mini_batch_X, mini_batch_Y) tuples; the last one may be smaller.\n",
    "    m = X.shape[0]\n",
    "    mini_batches = []\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    # Shuffle X and Y with the same permutation so every image keeps its label.\n",
    "    permutation = list(np.random.permutation(m))\n",
    "    shuffled_X = X[permutation,:,:,:]\n",
    "    shuffled_Y = Y[permutation,:]\n",
    "\n",
    "    # Full batches of exactly mini_batch_size examples.\n",
    "    num_complete_minibatches = math.floor(m/mini_batch_size)\n",
    "    for k in range(0, num_complete_minibatches):\n",
    "        mini_batch_X = shuffled_X[k * mini_batch_size : (k + 1) * mini_batch_size,:,:,:]\n",
    "        mini_batch_Y = shuffled_Y[k * mini_batch_size : (k + 1) * mini_batch_size,:]\n",
    "        mini_batches.append((mini_batch_X, mini_batch_Y))\n",
    "\n",
    "    # Leftover examples when m is not a multiple of mini_batch_size.\n",
    "    if m % mini_batch_size != 0:\n",
    "        mini_batch_X = shuffled_X[num_complete_minibatches * mini_batch_size : m,:,:,:]\n",
    "        mini_batch_Y = shuffled_Y[num_complete_minibatches * mini_batch_size : m,:]\n",
    "        mini_batches.append((mini_batch_X, mini_batch_Y))\n",
    "\n",
    "    return mini_batches"
   ]
  },
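  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`random_mini_batches` applies one shared permutation to `X` and `Y`, so each image keeps its label. It then cuts `floor(m / mini_batch_size)` full batches, plus one smaller batch for any remainder. The sketch below reproduces only that slicing logic on hypothetical toy data (150 examples and a batch size of 64, chosen purely for illustration), then checks the resulting batch sizes and the image/label pairing. `batch_slices` is an illustrative helper and is not part of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "def batch_slices(m, mini_batch_size):\n",
    "    # (start, end) index pairs produced by the slicing logic in random_mini_batches.\n",
    "    num_complete = math.floor(m / mini_batch_size)\n",
    "    slices = [(k * mini_batch_size, (k + 1) * mini_batch_size) for k in range(num_complete)]\n",
    "    if m % mini_batch_size != 0:\n",
    "        slices.append((num_complete * mini_batch_size, m))\n",
    "    return slices\n",
    "\n",
    "# Hypothetical toy data: example i is a 2x2x1 image filled with the value i,\n",
    "# and its label also stores i, so the pairing can be checked after shuffling.\n",
    "m = 150\n",
    "X = np.arange(m, dtype=float).reshape(m, 1, 1, 1) * np.ones((1, 2, 2, 1))\n",
    "Y = np.arange(m, dtype=float).reshape(m, 1)\n",
    "\n",
    "np.random.seed(0)\n",
    "permutation = np.random.permutation(m)\n",
    "shuffled_X, shuffled_Y = X[permutation], Y[permutation]\n",
    "\n",
    "sizes = [end - start for start, end in batch_slices(m, 64)]\n",
    "print(sizes)  # [64, 64, 22]\n",
    "print(np.all(shuffled_X[:, 0, 0, 0] == shuffled_Y[:, 0]))  # True"
   ]
  },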
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(X_train, Y_train, X_test, Y_test, learning_rate = 0.009,\n",
    "          num_epochs = 100, minibatch_size = 64, print_cost = True):\n",
    "    # Train the CNN with mini-batch Adam and report train/test accuracy.\n",
    "    # Returns (train_accuracy, test_accuracy, parameters).\n",
    "    ops.reset_default_graph()                # rebuild the graph from scratch on every call\n",
    "    tf.set_random_seed(1)                    # reproducible TensorFlow initialization\n",
    "    seed = 3                                 # NumPy seed for shuffling, incremented every epoch\n",
    "    (m, n_H0, n_W0, n_C0) = X_train.shape\n",
    "    n_y = Y_train.shape[1]\n",
    "    costs = []                               # average training cost of each epoch\n",
    "\n",
    "    # Graph: placeholders -> forward pass -> cost -> Adam update step\n",
    "    X, Y = create_placeholders(n_H0, n_W0, n_C0, n_y)\n",
    "    parameters = initialize_parameters()\n",
    "    Z3 = forward_propagation(X, parameters)\n",
    "    cost = compute_cost(Z3, Y)\n",
    "    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)\n",
    "    init = tf.global_variables_initializer()\n",
    "\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(init)\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            minibatch_cost = 0.\n",
    "            seed = seed + 1                  # reshuffle differently every epoch\n",
    "            minibatches = random_mini_batches(X_train, Y_train, minibatch_size, seed)\n",
    "            # Divide by the number of batches actually produced, including the smaller\n",
    "            # final batch, so that minibatch_cost is a true per-epoch average.\n",
    "            num_minibatches = len(minibatches)\n",
    "\n",
    "            for minibatch in minibatches:\n",
    "                (minibatch_X, minibatch_Y) = minibatch\n",
    "                # One Adam step on this batch; also fetch the batch cost.\n",
    "                _ , temp_cost = sess.run([optimizer, cost], feed_dict={X:minibatch_X, Y:minibatch_Y})\n",
    "                minibatch_cost += temp_cost / num_minibatches\n",
    "\n",
    "            if print_cost and epoch % 5 == 0:\n",
    "                print (\"Cost after epoch %i: %f\" % (epoch, minibatch_cost))\n",
    "            if print_cost:\n",
    "                costs.append(minibatch_cost)\n",
    "\n",
    "        # Learning curve: one point per epoch.\n",
    "        plt.plot(np.squeeze(costs))\n",
    "        plt.ylabel('cost')\n",
    "        plt.xlabel('epochs')\n",
    "        plt.title(\"Learning rate =\" + str(learning_rate))\n",
    "        plt.show()\n",
    "\n",
    "        # Accuracy: fraction of examples whose predicted class (argmax of Z3)\n",
    "        # equals the true class (argmax of the one-hot Y).\n",
    "        predict_op = tf.argmax(Z3, 1)\n",
    "        correct_prediction = tf.equal(predict_op, tf.argmax(Y, 1))\n",
    "        accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))\n",
    "        print(accuracy)                      # prints the Tensor object, not its value\n",
    "        train_accuracy = accuracy.eval({X: X_train, Y: Y_train})\n",
    "        test_accuracy = accuracy.eval({X: X_test, Y: Y_test})\n",
    "        print(\"Train Accuracy:\", train_accuracy)\n",
    "        print(\"Test Accuracy:\", test_accuracy)\n",
    "\n",
    "        return train_accuracy, test_accuracy, parameters"
   ]
  },
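  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To evaluate accuracy, `model` takes `argmax` over the class axis of both the logits `Z3` and the one-hot labels `Y`, compares the two index vectors, and averages the resulting booleans. The sketch below performs the same computation in NumPy on hypothetical toy arrays (4 examples, 3 classes). The `accuracy` function here is an illustrative stand-in for the TensorFlow ops and is not an existing API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def accuracy(logits, one_hot_labels):\n",
    "    # Predicted class = index of the largest logit in each row.\n",
    "    predictions = np.argmax(logits, axis=1)\n",
    "    targets = np.argmax(one_hot_labels, axis=1)\n",
    "    # Fraction of rows where the prediction matches the label.\n",
    "    return np.mean(predictions == targets)\n",
    "\n",
    "Z3 = np.array([[0.1, 2.0, -1.0],\n",
    "               [1.5, 0.2, 0.3],\n",
    "               [0.0, 0.1, 0.9],\n",
    "               [2.2, 0.4, 0.1]])\n",
    "Y = np.array([[0, 1, 0],\n",
    "              [1, 0, 0],\n",
    "              [1, 0, 0],\n",
    "              [0, 0, 1]])\n",
    "print(accuracy(Z3, Y))  # 0.5: rows 0 and 1 are correct, rows 2 and 3 are not"
   ]
  },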
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cost after epoch 0: 1.921332\n",
      "Cost after epoch 5: 1.904156\n",
      "Cost after epoch 10: 1.904309\n",
      "Cost after epoch 15: 1.904477\n",
      "Cost after epoch 20: 1.901876\n",
      "Cost after epoch 25: 1.784078\n",
      "Cost after epoch 30: 1.681051\n",
      "Cost after epoch 35: 1.618206\n",
      "Cost after epoch 40: 1.597971\n",
      "Cost after epoch 45: 1.566706\n",
      "Cost after epoch 50: 1.554487\n",
      "Cost after epoch 55: 1.502187\n",
      "Cost after epoch 60: 1.461036\n",
      "Cost after epoch 65: 1.304490\n",
      "Cost after epoch 70: 1.201760\n",
      "Cost after epoch 75: 1.163242\n",
      "Cost after epoch 80: 1.102885\n",
      "Cost after epoch 85: 1.087105\n",
      "Cost after epoch 90: 1.051911\n",
      "Cost after epoch 95: 1.018554\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEWCAYAAACJ0YulAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8HWXZ//HPlb1Z22brvu90YQnQQoFiUUBABMqiIgoosii4PY/gBrjwAwQB5UFAhIJoRRERUZCdAm2BllLovreUtlm6JmnTbNfvj5mGtCRtWnIySc73/XodkzNzn5lrmHq+mXtm7jF3R0REBCAh6gJERKT9UCiIiEgDhYKIiDRQKIiISAOFgoiINFAoiIhIA4WCdEpm9oyZfSXqOkQ6GoWCtCozW21mJ0Vdh7uf6u4PR10HgJm9YmZfa4P1pJrZg2a23cw2mtl399P+O2G7beHnUhvNG2BmL5vZDjNb3Hifhuu5w8zWm9kWM7vHzJJjuW3SdhQK0uGYWVLUNezWnmoBbgCGAv2BE4H/NbNTmmpoZicD1wKTgQHAIODGRk2mAXOBXOBHwONmlh/OuxYoAkYDw4DDgR+37qZIZNxdL71a7QWsBk5qZt7pwLvAVmAGMLbRvGuBFUA5sBA4q9G8rwJvAHcAm4FfhNNeB24DtgCrgFMbfeYV4GuNPr+vtgOB6eG6XwD+D3i0mW2YBKwDfgBsBP4IdAOeBkrD5T8N9Anb/xKoA6qACuDucPoI4Plwe5YA57XCf/sPgc80ev9z4C/NtP0zcFOj95OBjeHvw4BdQFaj+a8Bl4e/zwbObTTvi8AHUf/b06t1XjpSkDZhZocDDwLfIPjr8z7gqUZdFiuA44Acgr9YHzWzno0WcTSwEigg+KLdPW0JkAfcCvzBzKyZEvbV9s/AW2FdNwBf3s/m9AC6E/xFfhnBEfdD4ft+wE7gbgB3/xHBF+o33T3T3b9pZhkEgfDncHu+ANxjZoc0tbKwe2ZrM6/3wjbdgF7AvEYfnQc0ucxw+t5tC80sN5y30t3Lm1mWhS8ave9jZjnNrEs6EIWCtJWvA/e5+5vuXudBf/8uYDyAu//N3de7e727PwYsA45q9Pn17v5bd691953htDXu/nt3rwMeBnoChc2sv8m2ZtYPOBL4qbtXu/vrwFP72ZZ64Hp33+XuO919k7v/3d13hF+kvwRO2MfnTwdWu/tD4fa8A/wdmNJUY3e/0t27NvMaGzbLDH9ua/TRbUBWMzVkNtGWsP3e8/Ze1jPANWaWb2Y9gKvD6enNbrF0GO2pP1Q6t/7AV8zsW42mpRD8dYuZXQR8l6B/G4IvprxGbT9oYpkbd//i7jvCP/wzm2i3r7Z5wGZ337HXuvruY1tK3b1q9xszSyfo2jqFoCsJIMvMEsMQ2lt/4Ggz29poWhJBV9TBqgh/ZhN0Ve3+vbzp5lSE82nUlrD93vP2XtYvga4EXYG7gN8DhwElB1m7tCM6UpC28gHwy73+yk1392lm1p/gi+WbQK67dwXms2cXRayG890AdA+/2HfbVyA0Vcv3gOHA0e6eDRwfTrdm2n8AvLrXf4tMd7+iqZWZ2b1mVtHMawGAu28Jt2Vco4+OAxY0sw0Lmmhb7O6bwnmDzCxrr/m717XT3b/p7r3dfRCwCZjTTABKB6NQkFhINrO0Rq8kgi/9y83saAtkmNlp4RdPBsEXZymAmV1McGVLzLn7GoITpzeYWYqZTQDOOMDFZBGcR9hqZt2B6/eaX0xwdc9uTwPDzOzLZpYcvo40s5HN1Hh5GBpNvRqfM3gE+LGZdTOzEQRddlObqfkR4FIzGxWej/jx7rbuvpTgKOD6cP+dBYwl6OLCzHqbWa9wP44HftLENksHpVCQWPgPwZfk7tcN7j6b4EvqboIrdJYTXBWEuy8EbgdmEnyBjiG42qitfAmYQPAX7y+Axwi6RVrqTqALUAbM
Ap7da/5dwJTwmv7fhOcdPgNcAKwn6Nq6BUjlk7me4IT9GuBV4Ffu/iyAmfULjyz6AYTTbwVeDtuvYc8v9gsILjvdAtwMTHH30nDeYIKrxyoJzs9c6+7PfcLapZ0wdz1kR6QxM3sMWOzu+utX4o6OFCTuhV03g80sIbzZ60zgyajrEomCrj4SCe47eILgPoV1wBXuPjfakkSioe4jERFpoO4jERFp0OG6j/Ly8nzAgAFRlyEi0qHMmTOnzN3z99euw4XCgAEDmD17dtRliIh0KGa2piXt1H0kIiINFAoiItJAoSAiIg0UCiIi0kChICIiDRQKIiLSQKEgIiIN4iYUyip2ccNTC9hVq+eAiIg0J25C4c2Vm5k6YzXf/9t71NdrvCcRkaZ0uDuaD9ZpY3uydvMIbnl2MfmZqfzk9JGEz+kVEZFQ3IQCwOUnDKJ4exUPvrGK/KxULprQn7TkRBITFA4iIhBnoWBm/PT0UZRW7OKWZxdzy7OLAUhJSqBLciJpyQmkJSeS0OgIwgySEowEM8yM3UONu4Pv9Tx2w3CcunoP5++eHvyPQbgcPva53dPq6p06D5ZR7059fTA9KdFISUwgOTGB5KQEUhKNxARjZ3Ud23bWUF5VS1pyInmZKeRmplKYnUrPnC706tqFkT2zGNkjmwSFn4jsR1yFAkBCgvHr88YxaVg+myurqaqpZ0dNLbtq6tlZXUdVbR27HzHhEH4xO7WNzkMYQVg0/jJv/FiKxAQjIcEwPgqG+jAlnOCLvnEwuIfzw882vIKVYBi19fVU1wavmnqnprae2vp6uqan0C83g6y0JKqq6yirrKZ4exXvrdtKWUV1wzryMlOYOCSPI/p3Y3B+JoMLMinISlUXmojsIe5CASA1KZFzi/pGXUbMVdXUsX7rTt5Zu5XXlpXy2rIynnx3fcP8ySMKeOArRQoGEWkQl6EQL9KSExmUn8mg/EymHNEHd2fj9ipWllby3IKNPDxzDa8vL+O4ofsdYl1E4kTcXJIqwTmVnjldOHZIHj88bSQ9c9K464Vl6JGsIrKbQiFOpSYlcuWJQ5i9ZgtvLN8UdTki0k4oFOLYeUV96JmTxp0vLNXRgogACoW4lpqUyJWTButoQUQaKBTi3HlH9qVHdhp3v7ws6lJEpB1QKMS51KREzjy0F++s2UptXX3U5YhIxBQKwtDCLKrr6lmzeUfUpYhIxBQKwrDCTACWFZdHXImIRE2hIAwpCEJhaXFFxJWISNQUCkJ6ShJ9u3dhqY4UROKeQkEAGFaQxTIdKYjEPYWCAMHJ5pVlFdToCiSRuKZQECA42VxT56zZVBl1KSISIYWCADCsMAvQyWaReKdQEAAG52dihk42i8Q5hYIA0CUlkb7d0llWoiMFkXgWs1AwswfNrMTM5jczP8fM/mVm88xsgZldHKtapGWGFWbqBjaROBfLI4WpwCn7mH8VsNDdxwGTgNvNLCWG9ch+DC3MYlVZpa5AEoljMQsFd58ObN5XEyDLggcEZ4Zta2NVj+zf7iuQVpfpCiSReBXlOYW7gZHAeuB94Bp3b/JPVDO7zMxmm9ns0tLStqwxrgwt0BVIIvEuylA4GXgX6AUcCtxtZtlNNXT3+929yN2L8vP1kPlYGVKQSYKuQBKJa1GGwsXAEx5YDqwCRkRYT9xLS06kX/d0lpUoFETiVVKE614LTAZeM7NCYDiwMsJ6hOBk85srN/O1h99mRWkliQnGHecdypg+OVGXJiJtIJaXpE4DZgLDzWydmV1qZpeb2eVhk58Dx5jZ+8CLwA/cvSxW9UjLHDc0j8rqWtZt2cnInlnsrK7jvPtm8uz8jVGXJiJtwNw96hoOSFFRkc+ePTvqMjo1dye4KAxKy3dx2R9nM3ftVn742RFcdvzgiKsTkYNhZnPcvWh/7XRHs3zM7kAAyM9KZdrXx3PKIT246T+L+UCP7BTp1BQKsl9pyYl881NDAHhv3baIqxGRWFIoSIsMLcwkOdF4/0OF
gkhnplCQFklNSmRYYRYL1isURDozhYK02OheOcz/cBsd7eIEEWk5hYK02Og+OWzZUcOHW3dGXYqIxIhCQVpsdK9gFJL5H26PuBIRiRWFgrTYyJ7ZJCaYziuIdGIKBWmxtOREhhZkMl9XIIl0WgoFOSCH9Mrh/Q+362SzSCelUJADMrp3NmUVuygp3xV1KSISAwoFOSCjewejpaoLSaRzUijIARnVMxszdGezSCelUJADkpGaxKC8DF2WKtJJKRTkgI3unaPLUkU6KYWCHLAxvXPYsK2Kjduqoi5FRFqZQkEO2KTh+SQnGjc8tUCXpop0MgoFOWBDCrL4/meG8+yCjTz29gdRlyMirUihIAfl68cN4tghudz4r4UsL6mIuhwRaSUKBTkoCQnGr887lLTkBK75y1yqaur2mF9WsYtr/jKXRRt0lZJIR6JQkINWmJ3GrVPGsWD9dq54dE5DMFTsquXih97mn++u5zcvLou4ShE5EAoF+UQ+PaqQm84aw8tLSvnGH+dQXlXDFY/OYeGG7Rw9sDvPLSymeLuuUhLpKBQK8ol98eh+3HLOGKYvK+X4W1/mtWVl/L+zx3DLOWOpq3edjBbpQBQK0irOP7Ift54zlvKqWv73lOGcV9SXAXkZHDc0j2lvraW2rj7qEkWkBRQK0mrOLerL/BtP5spJQxqmfenofmzYVsUrS0ojrExEWkqhIK0qLTlxj/eTRxZSkJXKn95cE1FFInIgFAoSU8mJCVxwZF9eWVrK6rLKfbbdUV3Lrtq6fbYRkdhKiroA6fwuOKofv3t1BZN//Srj+uQwcWg+wwuzKMhOJS8zlffWbeVf8zYwfWkp4wfn8sglR0VdskjcUihIzPXq2oUnrzqWZ97fyGvLy7j7pWXU7zVkUo/sNI4e1J3pS0t5Y3kZxw7Ji6ZYkThnHW1As6KiIp89e3bUZcgnsL2qhg+37KSkfBel5bvon5vOEf26UV1Xz4m3vUJhdhr/uPIYzCzqUkU6DTOb4+5F+2unIwVpc9lpyWT3TGZkzz2npyUkcvXkoVz3xPu8uKiEk0YVRlOgSBzTiWZpV6Yc0YcBuenc9twS6vfuYxKRmFMoSLuSnJjAdz49jMUby3li7odRlyMSdxQK0u6cMbYXI3tm8/2/zePkO6Zz+3NLeHPlJkq2V+mhPiIxFrMTzWb2IHA6UOLuo5tpMwm4E0gGytz9hP0tVyea48PmymqenPsh/12wkbdXb264Wik9JZEj+nfjprPG0Ld7erRFinQgLT3RHMtQOB6oAB5pKhTMrCswAzjF3deaWYG7l+xvuQqF+LOpYhfvf7iNtZt3sLK0kr/PWQcGN589ltPG9tz/AkQk+quP3H26mQ3YR5MvAk+4+9qw/X4DQeJTbmYqk4YXNLy/dOJAvjVtLlf9+R2eeKeA7hkpONAlOZFD+3blqIHd6dOtC/UeHHG4OwXZadFtgEgHEuUlqcOAZDN7BcgC7nL3R5pqaGaXAZcB9OvXr80KlPapb/d0/nb5BO54fin/mPsh7mAG5VW1/HFWMMZSZmoSldW17D4QvvnsMVxw1Ef/djZXVnPH80v56rEDGJyfGcVmiLRLMb15LTxSeLqZ7qO7gSJgMtAFmAmc5u5L97VMdR9Jc+rrnaUl5by9egvLi8vJSU8hLzOF/y7YyKyVm3n44qOYODSPLZXVfPGBN1m0YTvj+uTwxJXHkpigG+Wkc4u8+6gF1hGcXK4EKs1sOjAO2GcoiDQnIcEY0SObET2y95h+1mG9mfK7mVzxpzk89NUjuf6pBaworeCrxwxg6ozVPDxjNZdMHBhR1SLtS5SXpP4TOM7MkswsHTgaWBRhPdJJZaUl8+DFR5KWnMiUe2eyrLiC+758BNefMYoTh+dz23NL+GDzjqjLFGkXYhYKZjaNoEtouJmtM7NLzexyM7scwN0XAc8C7wFvAQ+4+/xY1SPxrXfXLvzhK0WM6pnN7y48nBOHF2Bm/OKsMQD86Mn5ugdC
BA2IJ8JDb6zixn8t5KIJ/fnep4eTk54cdUkira4jnFMQaRcumjCAFaUV/HHWGp6at55rJg/lC0f1+9hT5ETigY4UREIL12/npv8s4vXlZaQmJXDUwO5MHJLH2D5d6Z+bTo/sNBJ0lZJ0UJHf0RwrCgWJJXdn1srNvLComNeWlbK0uKJhXkpSAoPzMxnVM5tRvbKZNDxf9zhIh6FQEGkFJdurWFpcwepNlazZVMnS4goWbthOafkukhKMS48byDWTh5KekkTlrlqeX1hMQoJxxtieekiQtCs6pyDSCgqy0yjITmPi0D0fD7ph207ueH4p9726kqfnbeDw/t14YWExO2vqAHhxUTG3nDNW5yWkw9GRgsgn8Naqzfzkyfls3F7FZ8f05POH9mL2mi3c9twSRvfK4bZzx1Gxq5Y1myoBOGNcL5ITNWK9tD11H4m0IXffo7vohYXFfPuxd6nYVbtHuxE9srjp7DEc3q9bW5cocU6hIBKxVWWVvL6slN7dutA/N4PlJRVc/88FFJdXcdH4/vzwtJGkJql7SdqGzimIRGxgXgYD8zIa3g/Oz+SYwbnc9t8lPDxzDYs3lnP/l4t0s5y0K+rcFGlDWWnJ3HjmaO664FDmrt3KOffO0LhL0q4oFEQicOahvXnk0qMo2V7FWffM4MOtO6MuSQRQKIhEZvygXB6/4hh2Vtfyg8ff22NAvtq6ev7z/gaqwktcRdqKQkEkQsMKs/jhaSN5fXkZj765FoC6eue7f53HlX96hzue3/PxIlU1dVwy9W1eWlwcRbkSBxQKIhH74lH9OG5oHv/vP4tYXVbJD/7+Hk/NW8+A3HQemrGa9Y26lu55ZQUvLS7h1meXaKhviQmFgkjEzIxbzhlLohmfu/t1Hp+zju+cNIxHv3Y0ONz5QnC0sGZTJfe+uoJeOWks3ljO68vLIq5cOiOFgkg70KtrF356xii2V9Vy5aTBXD15CH26pfPlCf15fM46lhWXc+O/FpKcYDz2jQnkZ6Xy+9dWRV22dEItCgUzO7cl00Tk4J1b1JdZ103mf04e3nB39FUnDiEjJYnL/jiHlxaX8O2ThtG3ezpfmdCf6UtLWbKxPOKqpbNp6ZHCdS2cJiKfQI+ctD2Gy+iekcJlxw9iVVklQwsy+eqxAwD40tH9SUtO4A+vr4yoUums9nlHs5mdCnwW6G1mv2k0KxuobfpTItKaLj1uIGs27+CiCf0bBtPrlpHClCP68Ne31/E/J48gPys14iqls9jfkcJ6YDZQBcxp9HoKODm2pYkIQHpKEredO46xfbruMf2SYwdSU1/Pfa+uiKgy6Yz2eaTg7vOAeWb2Z3evATCzbkBfd9/SFgWKSNMG5WdyflFfHnxjFZ8d21Mjr0qraOk5hefNLNvMugPzgIfM7NcxrEtEWuBHp42kR3Ya3//bPN39LK2ipaGQ4+7bgbOBh9z9COCk2JUlIi2RlZbMrVPGsbK0ktv+uyTqcqQTaGkoJJlZT+A84OkY1iMiB2ji0DwuHN+PP7yxirdWbY66HOngWhoKPwP+C6xw97fNbBCwLHZliciBuO7UkRRmpemks3xiLXrIjrv/Dfhbo/crgXNiVZSIHJiM1CROHFHA0/PWU1fvJCbY/j8k0oSW3tHcx8z+YWYlZlZsZn83sz6xLk5EWm78oO6U76pl4frtUZciHVhLu48eIrg3oRfQG/hXOE1E2onxg3IBmLlSA+XJwWtpKOS7+0PuXhu+pgL5MaxLRA5QYXYag/IymLVSJ5vl4LU0FMrM7EIzSwxfFwKbYlmYiBy48YNzeXvVZmrr6qMuRTqolobCJQSXo24ENgBTgItjVZSIHJzxg3Ip31XLAp1XkIPU0lD4OfAVd8939wKCkLghZlWJyEEZP7A7ALNW6kBeDk5LQ2Fs47GO3H0zcFhsShKRg1WQncbg/AyFghy0loZCQjgQHgDhGEgtusdBRNrW+EG5vL16i84ryEFpaSjcDswws5+b2c+AGcCt
+/qAmT0Y3tcwfz/tjjSzOjOb0sJaRGQfxg/KpWJXLfN1XkEOQotCwd0fIbiDuRgoBc529z/u52NTgVP21cDMEoFbCIbQEJFWcPQgnVeQg9fSIwXcfaG73+3uv3X3hS1oPx3Y3wXT3wL+DpS0tA4R2beCrDSGFmTy3IKNUZciHVCLQ6G1mVlv4Czg3ha0vczMZpvZ7NLS0tgXJ9LBXTi+P++s3cqbOlqQAxRZKAB3Aj9w9/0+GcTd73f3Incvys/XjdQi+3P+kX3Jy0zl7peXR12KdDBRhkIR8BczW01wM9w9Zvb5COsR6TTSkhP5+nEDeW1ZGXPX6sm50nKRhYK7D3T3Ae4+AHgcuNLdn4yqHpHO5kvj+9M1PZn/09GCHICYhYKZTQNmAsPNbJ2ZXWpml5vZ5bFap4h8JDM1iUuOHcgLi0pYsH5b1OVIB2HuHnUNB6SoqMhnz54ddRkiHcK2nTVMvPklhhRmctu54xicnxl1SRIRM5vj7kX7axflOQURibGcLsn87POHsKy4gpPvmM4NTy1gS2V11GVJO6ZQEOnkzjqsDy9/fxLnHdmXR2au5tS7XmNFaUXUZUk7pVAQiQP5WancdNYYnvrmRGrr6zn/vlksLS6PuixphxQKInFkdO8c/nLZeBIMLrh/lp7nLB+jUBCJM0MKsnjsGxNITUrgkqlvU1ffsS42kdhSKIjEoYF5GVx76gg2bq9i3rqtUZcj7YhCQSROnTAsn8QE46VFGo9SPqJQEIlTXdNTOKJ/N15arFCQjygUROLYp0YUsHDDdjZs2xl1KdJOKBRE4tjkEQUAOlqQBgoFkTg2pCCTvt278LJCQUIKBZE4ZmZ8angBry8vo6pmv482kTigUBCJc58aWUhVTT0zV+gpbaJQEIl7Rw/sTnpKos4rCKBQEIl7acmJHDskjxcXFevuZlEoiAicc3gf1m+r4pGZq6MuRSKmUBARTj6kkOOH5XP7c0vZuK0q6nIkQgoFEcHM+MWZo6mpq+fGfy2IuhyJkEJBRADol5vO1ZOH8sz8jby0uDjqciQiCgURafD14wYxpCCTnzy5QN1IcUqhICINUpIS+NWUsWzbWcNZ97zB4o16CE+8USiIyB4O69eNx74xnnp3pvxuJq8tK426JGlDCgUR+ZhDeuXwjyuPpU+3Llwy9W1Wl1VGXZK0EYWCiDSpV9cuTL34KOrqncfnrIu6HGkjCgURaVaPnDSOG5rPE++so153O8cFhYKI7NM5RwR3O89cqQHz4oFCQUT26TOjCslKS+Lv6kKKCwoFEdmntORETh/bi2fmb6RiV23U5UiMKRREZL+mHNGHnTV1/Of9DVGXIjGmUBCR/Tq8X1cG5mXoKqQ4oFAQkf0yM6Yc0Ye3Vm1m/ofboi5HYkihICItcv6RfemRncbXH5lN8XaNi9RZKRREpEXyMlP5w1eL2L6zhkumvk2lTjp3SgoFEWmxQ3rlcPeXDmfxxnK+NW0utXX1UZckrSxmoWBmD5pZiZnNb2b+l8zsvfA1w8zGxaoWEWk9Jw4v4MbPHcJLi0u4/NE57KjWEUNnEssjhanAKfuYvwo4wd3HAj8H7o9hLSLSii4c35+fnxkEwwX3z6KkPDjHUFaxi5cXl6hrqQNLitWC3X26mQ3Yx/wZjd7OAvrEqhYRaX1fnjCAnjld+Na0uXzut2+QnprIytJgNNWzD+vNr88/NOIK5WC0l3MKlwLPRF2EiByYk0YV8tg3xlOYk8bA3AyuPXUEXziqH0/M/ZA3NVZShxSzI4WWMrMTCUJh4j7aXAZcBtCvX782qkxEWmJsn67886pjG97vrK5j+tJSfvLP+fz76uNITmwvf3tKS0S6t8xsLPAAcKa7N/tnhbvf7+5F7l6Un5/fdgWKyAHrkpLIDZ87hKXFFUx9Y3XU5cgBiiwUzKwf8ATwZXdfGlUdItL6Pj2qkMkjCrjzhaUs2rAddz2LoaOwWO0sM5sGTALygGLgeiAZwN3vNbMH
gHOANeFHat29aH/LLSoq8tmzZ8ekZhFpPWs37eDkO6ezs6aObunJjOvblatOHMKRA7pHXVpcMrM5LfmOjVkoxIpCQaTjWLtpB68tL2XeB1t5ZUkp9Q4vfvcEctKToy4t7rQ0FHQGSERipl9uOl86uj+3ThnHg189ks2Vu7j52cVRlyX7oFAQkTYxuncOl04cyLS31vL26s0N01eXVbJh284IK5PGFAoi0ma+8+lh9O7aheueeJ/FG7dz9bS5nHj7K1xw/yyqazWOUnugUBCRNpOeksQvPj+a5SUVnHLnazy/sJgzxvZizaYd/OXttVGXJ7SDm9dEJL6cOKKAb580lJ3VdXz9+EHkZqRQUl7Fb15cxtmH9yEzVV9LUdKRgoi0uW+fNIzrPjuSvMxUzIwfnDKCsopqHnhtZdSlxT2FgohE7rB+3Th1dA9+P30lZRW7Pja/clctK0orIqgs/igURKRd+P7Jw6mqrefmZxZTV//R/VMbt1Vx1j1v8Jk7pvPv9zZEWGF8UOediLQLg/MzuXTiQO6fvpK1m3dw+7njqK13LnzgTbbtrGFUz2yu/stcEgxOHdMz6nI7LYWCiLQb1506gmGFWdzw1AJOves1UpKCzoxpXx/PwPwMvvLgW3xr2lzuNjhltIIhFtR9JCLthpkx5Yg+PHPNcYzunU1GaiJ/u3wCY/rkkJmaxNSLj2Rsnxy+NW0uS4vL9/jslspqHp6xml21dRFV3zkoFESk3enbPZ2/XDaBV79/IoPzMxumZ6Ul8/uLishMTeK6J96nPjz3UF/vfOev73L9Uwt4dJbud/gkFAoi0m4lJNjHpuVmpvKj00YxZ80W/vxWEAAPzVjNK0tK6Z6Rwr2vrqCqRkcLB0uhICIdzjmH9+aYwbnc8sxiXlpczM3PLOKkkYXc86XDKS3fxZ/e1NHCwVIoiEiHY2bcdNYYquvquWTqbHIzUvnVlLGMH5TLhEG5Olr4BBQKItIhDcjL4LufHkZSgvHr88fRLSMFgG+fNHSPo4XNldW8sLBYIdFCesiOiHRo26tqyE7b86E9X/z9LJZsLGdwfiaz12ym3uG8oj7cOmVcRFVGTw/ZEZG4sHcgAHzvM8PYtrOG8l21fPNTQ7loQn/+Onsd/5gGiSaUAAANK0lEQVS7rk1rq9xVyytLStp0nZ+Ubl4TkU7niP7dWfCzk0lNSgSgtq6exRvL+dE/5jOmd1cG5WXw1Lz1/O6VFVw4oT9fHt8/JnX89qXl3PvqCl763gkManRpbXumUBCRTml3IAAkJSbwmwsO49S7pnPFo3NITU5g/ofbyUpN4oanFjA4P4NjBuc1tF9WXE73jBRyM1MPev01dfU8Pic4MpmxYlOHCQV1H4lIXOiRk8avzz+UZSUVbKms4Y7zx/HGdZ9iQG463/rzXNZv3UldvXPXC8s4+c7pXPHoO59ofS8vLqGsYhcJBjNXbmqlrYg9HSmISNw4cXgBL37vBHp37UJacnAkcd+Xi/j8/73B5Y/OITM1iRkrNjGsMJO3Vm/mrVWbOWpg94Na119nf0BBVipHD8pl5ooy3B2zj9+M197oSEFE4srg/MyGQAAYUpDJ7eeN471123hn7RZuPWcs/7xqIrkZKdzzyvKDWkfx9ipeXlLKOUf04bgheZRVVLOspGM8D0JHCiIS904+pAdTLz6SPt3SGVIQ9P1fMnEgv/rvEhas38YhvXIOaHmPz1lHXb1zXlFfksKhOmau2MSwwqxWr7216UhBRASYNLygIRAALhzfn6zUJO55ZcUBLcfd+dvsDzhqYHcG5mXQt3s6vbt2YeaKjnFeQaEgItKEnC7JXDihP/95fwMr93oU6LadNcxYUcaM5WV7DNVdV+/8Y+6HrN60g/OL+jZMnzA4l1mrNjWM6tqeqftIRKQZlxw7kAdfX8V5980kLzOVtORENldWs3bzjoY2GSmJHD8sn24ZKTy3YCNlFdX07tqFzzZ6Otwxg3N5fM46Fm8sZ1Sv
7Cg2pcUUCiIizcjPSuWms8bwytJSqmrqqKqpo3fXLpx/ZF9G986htq6eFxeX8NKiErburGbyiEJOG9uTE4cX0CXlo5PZEwbnAjBjRZlCQUSkIzvniD6cc0SfZudPHlmIf96pq3eSEpvuke+Z04UBuenMWrmJrx03KFaltgqdUxAR+YTMrNlA2G3C4FzeXLWZunZ+XkGhICLSBk4Ylk95VS23PruY9jw6tUJBRKQNnHxIDy4c34/7pq/kjueXRl1Os3ROQUSkDZgZP/vcaGpqnd+8tJzaemd4jyyWFVeweUc1V5wwmL7d06MuU6EgItJWEhKMm84eQ01dfcNNcYkJRqIZ76zZwhNXHkN6SrRfyzFbu5k9CJwOlLj76CbmG3AX8FlgB/BVd/9kwxKKiLRziQnGr84dx3lH9qVbegoD8zKYuXITFz/0Fv/z+Hvc/YXDIh04L5bnFKYCp+xj/qnA0PB1GfC7GNYiItJuJCYY4wflMrxHFilJCZwwLJ//OXkE/35vA/dNXwkEw2VU1dRRXVvfpiemY3ak4O7TzWzAPpqcCTziwdbOMrOuZtbT3TfEqiYRkfbq8hMGMf/Dbdz67GIefH0VW3fWUF1bDwQhkp6cyKXHDeTbJw2LaR1Rdl71Bj5o9H5dOE2hICJxx8y4dcpYCrPT2FFdS056Mjldkqmvd3bW1LGzuv6AR2s9GFGGQlOdZk0eI5nZZQRdTPTr1y+WNYmIRCYjNYmfnjEq0hqivE9hHdC30fs+wPqmGrr7/e5e5O5F+fn5bVKciEg8ijIUngIussB4YJvOJ4iIRCuWl6ROAyYBeWa2DrgeSAZw93uB/xBcjrqc4JLUi2NVi4iItEwsrz76wn7mO3BVrNYvIiIHTmMfiYhIA4WCiIg0UCiIiEgDhYKIiDSw9vywh6aYWSmw5iA/ngeUtWI5HUU8bnc8bjPE53bH4zbDgW93f3ff741eHS4UPgkzm+3uRVHX0dbicbvjcZshPrc7HrcZYrfd6j4SEZEGCgUREWkQb6Fwf9QFRCQetzsetxnic7vjcZshRtsdV+cURERk3+LtSEFERPZBoSAiIg3iJhTM7BQzW2Jmy83s2qjriQUz62tmL5vZIjNbYGbXhNO7m9nzZrYs/Nkt6lpjwcwSzWyumT0dvh9oZm+G2/2YmaVEXWNrCh9h+7iZLQ73+YR42Ndm9p3w3/d8M5tmZmmdcV+b2YNmVmJm8xtNa3L/ho8g+E34/faemR1+sOuNi1Aws0Tg/4BTgVHAF8ws2scbxUYt8D13HwmMB64Kt/Na4EV3Hwq8GL7vjK4BFjV6fwtwR7jdW4BLI6kqdu4CnnX3EcA4gm3v1PvazHoDVwNF7j4aSAQuoHPu66nAKXtNa27/ngoMDV+XAb872JXGRSgARwHL3X2lu1cDfwHOjLimVufuG9z9nfD3coIvid4E2/pw2Oxh4PPRVBg7ZtYHOA14IHxvwKeAx8MmnWq7zSwbOB74A4C7V7v7VuJgXxMM+d/FzJKAdILnune6fe3u04HNe01ubv+eCTzigVlAVzPreTDrjZdQ6A180Oj9unBap2VmA4DDgDeBwt1PtQt/FkRXWczcCfwvUB++zwW2untt+L6z7fNBQCnwUNhl9oCZZdDJ97W7fwjcBqwlCINtwBw6975urLn922rfcfESCtbEtE57La6ZZQJ/B77t7tujrifWzOx0oMTd5zSe3ETTzrTPk4DDgd+5+2FAJZ2sq6gpYR/6mcBAoBeQQdB1srfOtK9botX+vcdLKKwD+jZ63wdYH1EtMWVmyQSB8Cd3fyKcXLz7UDL8WRJVfTFyLPA5M1tN0DX4KYIjh65hFwN0vn2+Dljn7m+G7x8nCInOvq9PAla5e6m71wBPAMfQufd1Y83t31b7jouXUHgbGBpeoZBCcGLqqYhranVhP/ofgEXu/utGs54CvhL+/hXgn21dWyy5+3Xu3sfdBxDs25fc/UvAy8CUsFmn2m533wh8YGbD
w0mTgYV08n1N0G003szSw3/vu7e70+7rvTS3f58CLgqvQhoPbNvdzXSg4uaOZjP7LMFfj4nAg+7+y4hLanVmNhF4DXifj/rWf0hwXuGvQD+C/1Od6+57n8DqFMxsEvB9dz/dzAYRHDl0B+YCF7r7rijra01mdijBifUUYCVwMcEfep16X5vZjcD5BFfbzQW+RtB/3qn2tZlNAyYRDJFdDFwPPEkT+zcMyLsJrlbaAVzs7rMPar3xEgoiIrJ/8dJ9JCIiLaBQEBGRBgoFERFpoFAQEZEGCgUREWmgUJB2w8xmhD8HmNkXW3nZP2xqXbFiZp83s5/GaNk/3H+rA17mGDOb2trLlY5Hl6RKu9P4XoMD+Eyiu9ftY36Fu2e2Rn0trGcG8Dl3L/uEy/nYdsVqW8zsBeASd1/b2suWjkNHCtJumFlF+OvNwHFm9m44dn6imf3KzN4Ox4r/Rth+kgXPj/gzwQ17mNmTZjYnHG//snDazQSjar5rZn9qvK7wDtBfhWPzv29m5zda9iv20fMK/hTeIISZ3WxmC8NabmtiO4YBu3YHgplNNbN7zew1M1sajtW0+/kPLdquRstualsuNLO3wmn3hUPFY2YVZvZLM5tnZrPMrDCcfm64vfPMbHqjxf+L4I5wiWfurpde7eIFVIQ/JwFPN5p+GfDj8PdUYDbBgGiTCAaCG9iobffwZxdgPpDbeNlNrOsc4HmCO90LCe4S7RkuexvBGDIJwExgIsEds0v46Ci7axPbcTFwe6P3U4Fnw+UMJRinJu1Atqup2sPfRxJ8mSeH7+8BLgp/d+CM8PdbG63rfaD33vUTjCH1r6j/HegV7Wv3AFIi7dlngLFmtntsmxyCL9dq4C13X9Wo7dVmdlb4e9+w3aZ9LHsiMM2DLppiM3sVOBLYHi57HYCZvQsMAGYBVcADZvZv4OkmltmTYFjrxv7q7vXAMjNbCYw4wO1qzmTgCODt8ECmCx8NklbdqL45wKfD398ApprZXwkGlNuthGDkUYljCgXpCAz4lrv/d4+JwbmHyr3enwRMcPcdZvYKwV/k+1t2cxqPnVMHJLl7rZkdRfBlfAHwTYJRWRvbSfAF39jeJ++cFm7XfhjwsLtf18S8Gnffvd46wv+/u/vlZnY0wUOJ3jWzQ919E8F/q50tXK90UjqnIO1ROZDV6P1/gSssGBYcMxtmwQNl9pYDbAkDYQTBI0l3q9n9+b1MB84P+/fzCZ5m9lZzhVnwrIocd/8P8G3g0CaaLQKG7DXtXDNLMLPBBA/IWXIA27W3xtvyIjDFzArCZXQ3s/77+rCZDXb3N939p0AZHw25PIygy03imI4UpD16D6g1s3kE/fF3EXTdvBOe7C2l6cctPgtcbmbvEXzpzmo0737gPTN7x4NhtXf7BzABmEfw1/v/uvvGMFSakgX808zSCP5K/04TbaYDt5uZNfpLfQnwKsF5i8vdvcrMHmjhdu1tj20xsx8Dz5lZAlADXAWs2cfnf2VmQ8P6Xwy3HeBE4N8tWL90YrokVSQGzOwugpO2L4TX/z/t7o/v52ORMbNUgtCa6B891lLikLqPRGLjJoKHyncU/YBrFQiiIwUREWmgIwUREWmgUBARkQYKBRERaaBQEBGRBgoFERFp8P8BRxbXOg1qMQsAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"Mean_1:0\", shape=(), dtype=float32)\n",
      "Train Accuracy: 0.6666667\n",
      "Test Accuracy: 0.5833333\n"
     ]
    }
   ],
   "source": [
    "_, _, parameters = model(X_train, Y_train, X_test, Y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {
    "height": "424px",
    "width": "342px"
   },
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
